Example #1
    def __init__(
        self,
        worker_id,
        address,
        port,
        worker_type="RAW",
        result_size_limit=DEFAULT_RESULT_SIZE_LIMIT_B,
    ):

        self.worker_id = worker_id
        self.address = address
        self.port = port
        self.worker_type = worker_type
        self.serializer = FuncXSerializer()
        self.serialize = self.serializer.serialize
        self.deserialize = self.serializer.deserialize
        self.result_size_limit = result_size_limit

        log.info(f"Initializing worker {worker_id}")
        log.info(f"Worker is of type: {worker_type}")

        self.context = zmq.Context()
        self.poller = zmq.Poller()
        self.identity = worker_id.encode()

        self.task_socket = self.context.socket(zmq.DEALER)
        self.task_socket.setsockopt(zmq.IDENTITY, self.identity)

        log.info(f"Trying to connect to : tcp://{self.address}:{self.port}")
        self.task_socket.connect(f"tcp://{self.address}:{self.port}")
        self.poller.register(self.task_socket, zmq.POLLIN)
        signal.signal(signal.SIGTERM, self.handler)
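The worker above connects a zmq.DEALER socket, identified by its worker_id, back to the manager. As a rough illustration of the other end of that link, here is a minimal sketch of a ROUTER peer that could accept such a connection; the port and the single-frame message framing are assumptions for illustration, not the real funcX manager protocol.

import zmq

# Minimal ROUTER peer sketch for the DEALER worker above.
# Port 50010 and the one-frame payload are placeholders; the actual
# funcX manager speaks its own multi-frame protocol.
context = zmq.Context()
router = context.socket(zmq.ROUTER)
router.bind("tcp://*:50010")

frames = router.recv_multipart()            # frames[0] is the DEALER's identity
identity, payload = frames[0], frames[1:]
print(f"worker {identity!r} sent {payload!r}")
router.send_multipart([identity, b"TASK"])  # replies are addressed by identity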
Example #2
def serialize_fx_inputs(*args, **kwargs):
    from funcx.serialize import FuncXSerializer
    fx_serializer = FuncXSerializer()
    ser_args = fx_serializer.serialize(args)
    ser_kwargs = fx_serializer.serialize(kwargs)
    payload = fx_serializer.pack_buffers([ser_args, ser_kwargs])
    return payload
Example #3
def serialize_fx_inputs(*args, **kwargs):
    """Pack and serialize inputs
    """
    fx_serializer = FuncXSerializer()
    ser_args = fx_serializer.serialize(args)
    ser_kwargs = fx_serializer.serialize(kwargs)
    payload = fx_serializer.pack_buffers([ser_args, ser_kwargs])
    return payload
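Both helpers above rely on FuncXSerializer round-tripping the call arguments and on pack_buffers combining the two serialized buffers into one payload string. A minimal sanity-check sketch, using only the serialize/deserialize/pack_buffers calls already shown in these examples:

from funcx.serialize import FuncXSerializer

fxs = FuncXSerializer()
kwargs = {"duration": 0}

ser_kwargs = fxs.serialize(kwargs)          # opaque string buffer
assert fxs.deserialize(ser_kwargs) == kwargs

# pack_buffers combines several serialized buffers into a single payload string
payload = fxs.pack_buffers([fxs.serialize([1, 2, 3]), ser_kwargs])
print(type(payload), len(payload))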
Example #4
 def get_funcx_function_checksum(funcx_function):
     """
     Get the SHA256 checksum of a funcx function
     :returns sha256 hex string of a given funcx function
     """
     fxs = FuncXSerializer()
     serialized_func = fxs.serialize(funcx_function).encode()
     return hashlib.sha256(serialized_func).hexdigest()
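A short usage sketch for the checksum helper above; the sample function is made up for illustration, and the digest is simply SHA-256 over the serialized function body as shown in the helper.

import hashlib
from funcx.serialize import FuncXSerializer

def double(x):
    return x * 2

fxs = FuncXSerializer()
serialized_func = fxs.serialize(double).encode()
checksum = hashlib.sha256(serialized_func).hexdigest()
print(checksum)  # 64-character hex digest of the serialized function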
Example #5
class Batch:
    """Utility class for creating batch submission in funcX"""
    def __init__(self):
        self.tasks = []
        self.fx_serializer = FuncXSerializer()

    def add(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Add an function invocation to a batch submission

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required

        Returns
        -------
        None
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'function': function_id,
            'payload': payload
        }

        self.tasks.append(data)

    def prepare(self):
        """Prepare the payloads to be post to web service in a batch

        Parameters
        ----------

        Returns
        -------
        Dict[str, list] of prepared payloads
        """
        data = {'tasks': []}

        for task in self.tasks:
            new_task = (task['function'], task['endpoint'], task['payload'])
            data['tasks'].append(new_task)

        return data
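A usage sketch for the Batch class above; the endpoint and function UUIDs are placeholders, and the prepared data would normally be handed to a batch submission call rather than printed.

batch = Batch()

# Placeholder UUIDs, for illustration only
ep_id = "00000000-0000-0000-0000-000000000000"
fn_id = "11111111-1111-1111-1111-111111111111"

batch.add(1, 2, endpoint_id=ep_id, function_id=fn_id, duration=0)
batch.add(3, 4, endpoint_id=ep_id, function_id=fn_id, duration=0)

data = batch.prepare()
print(len(data['tasks']))  # 2 tuples of (function, endpoint, payload)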
Example #6
    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://dev.funcx.org/api/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer : :class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://dev.funcx.org/api/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        native_client = NativeClient(client_id=self.CLIENT_ID)

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"

        if not fx_authorizer:
            native_client.login(
                requested_scopes=[fx_scope],
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = native_client.get_authorizers_by_scope(
                requested_scopes=[fx_scope])
            fx_authorizer = all_authorizers[fx_scope]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()
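The constructor above skips the interactive login when fx_authorizer is supplied. A hedged sketch of that path, assuming a funcX-scoped Globus access token has already been obtained elsewhere; globus_sdk.AccessTokenAuthorizer is the standard wrapper for such a token.

import globus_sdk

# 'token' is assumed to be a valid funcX-scoped Globus access token
# acquired through some other flow; it is a placeholder here.
token = "..."
fx_authorizer = globus_sdk.AccessTokenAuthorizer(token)

client = FuncXClient(fx_authorizer=fx_authorizer, http_timeout=60)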
Example #7
    def __init__(self, task_group_id=None):
        """
        Parameters
        ==========

        task_group_id : str
            UUID indicating the task group that this batch belongs to
        """
        self.tasks = []
        self.fx_serializer = FuncXSerializer()
        self.task_group_id = task_group_id
Example #8
def deserialize():
    """Return the deserialized result
    """

    fx_serializer = FuncXSerializer()
    # Return a failure message if all else fails
    ret_package = {'error': 'Failed to deserialize result'}
    try:
        inputs = request.json
        res = fx_serializer.deserialize(inputs)
        ret_package = jsonify(res)
    except Exception as e:
        print(e)
        return jsonify(ret_package), 500
    return ret_package, 200
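The Flask view above reads a FuncXSerializer buffer out of the JSON request body and deserializes it. A client-side sketch, assuming the service is reachable at a hypothetical http://localhost:5000/deserialize route:

import requests
from funcx.serialize import FuncXSerializer

fxs = FuncXSerializer()
payload = fxs.serialize({"answer": 42})

# URL and route are assumptions for illustration
resp = requests.post("http://localhost:5000/deserialize", json=payload)
print(resp.status_code, resp.json())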
Example #9
    def __init__(self, gsearchresult):
        """

        Parameters
        ----------
        gsearchresult : dict
        """
        # wrapper for an array of results
        results = gsearchresult['results']
        super().__init__(results)

        # track data about where we are in total results
        self.has_next_page = gsearchresult['has_next_page']
        self.offset = gsearchresult['offset']
        self.total = gsearchresult['total']

        # we can use this to load functions and run them
        self.serializer = FuncXSerializer()

        # Reformat for pretty printing and easy viewing
        self._init_columns()
        self.table = Texttable(max_width=120)
        self.table.header(self.columns)
        for res in self:
            self.table.add_row([
                res[col] for col in self.columns
            ])
Example #10
def test(endpoint_id=None, tasks=10, hostname=None, port=None):
    tasks_rq = RedisQueue(f'task_{endpoint_id}', hostname)
    results_rq = RedisQueue('results', hostname)
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    tasks_rq.connect()
    results_rq.connect()

    # Drain any results left over from a previous run
    while True:
        try:
            _ = results_rq.get(timeout=1)
        except Exception:
            print("No more results left")
            break

    start = time.time()
    for i in range(tasks):
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({'duration':0})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        tasks_rq.put(f"0{i}", payload)

    for i in range(tasks):
        res = results_rq.get(timeout=1)
        print("Result : ", res)

    delta = time.time() - start
    print("Time to complete {} tasks: {:8.3f} s".format(tasks, delta))
    print("Throughput : {:8.3f} Tasks/s".format(tasks / delta))
    return delta
Example #11
def dont_run_yet(endpoint_id=None, tasks=10, duration=1, hostname=None):
    # tasks_rq = EndpointQueue(f'task_{endpoint_id}', hostname)
    tasks_channel = RedisPubSub(hostname)
    tasks_channel.connect()
    redis_client = tasks_channel.redis_client
    redis_client.ping()
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    start = time.time()
    task_ids = {}
    for i in range(tasks):
        time.sleep(duration)
        task_id = str(uuid.uuid4())
        print("Task_id : ", task_id)
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({"duration": duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "RAW"
        task = Task(redis_client, task_id, container_id, serializer="", payload=payload)
        task.endpoint = endpoint_id
        task.status = TaskState.WAITING_FOR_EP
        # tasks_rq.enqueue(task)
        tasks_channel.put(endpoint_id, task)
        task_ids[i] = task_id

    d1 = time.time() - start
    print(f"Time to launch {tasks} tasks: {d1:8.3f} s")

    delay = 5
    print(f"Sleeping {delay} seconds")
    time.sleep(delay)
    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        task_id = task_ids[i]
        print("Task_id : ", task_id)
        task = Task.from_id(redis_client, task_id)
        # TODO: wait for task result...
        time.sleep(duration)
        try:
            result = fxs.deserialize(task.result)
            print(f"Result : {result}")
        except Exception as e:
            print(f"Task failed with exception:{e}")
            pass

    delta = time.time() - start
    print(f"Time to complete {tasks} tasks: {delta:8.3f} s")
    print(f"Throughput : {tasks / delta:8.3f} Tasks/s")
    return delta
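The Redis benchmark snippets in Examples #10 and #11 (and the similar ones in Examples #17 and #18 below) serialize a slow_double helper that is not shown on this page. A plausible sketch, reconstructed from how it is called (an integer positional argument and a 'duration' keyword); the original helper may differ in detail.

import time

def slow_double(x, duration=0):
    # Assumed definition: sleep for `duration` seconds, then return twice the input.
    time.sleep(duration)
    return x * 2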
Example #12
    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer : :class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        Keyword arguments are the same as for BaseClient.
        """
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if force_login or not fx_authorizer:
            fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
            auth_res = login(services=[fx_scope],
                             app_name="funcX_Client",
                             client_id=self.CLIENT_ID,
                             clear_old_tokens=force_login,
                             token_dir=self.TOKEN_DIR)
            dlh_authorizer = auth_res['funcx_service']
        else:
            dlh_authorizer = fx_authorizer

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=dlh_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=self.FUNCX_SERVICE_ADDRESS,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()
Example #13
    def __init__(self,
                 worker_id,
                 address,
                 port,
                 logdir,
                 debug=False,
                 worker_type='RAW'):

        self.worker_id = worker_id
        self.address = address
        self.port = port
        self.logdir = logdir
        self.debug = debug
        self.worker_type = worker_type
        self.serializer = FuncXSerializer()
        self.serialize = self.serializer.serialize
        self.deserialize = self.serializer.deserialize

        global logger
        logger = set_file_logger(
            '{}/funcx_worker_{}.log'.format(logdir, worker_id),
            name="worker_log",
            level=logging.DEBUG if debug else logging.INFO)

        logger.info('Initializing worker {}'.format(worker_id))
        logger.info('Worker is of type: {}'.format(worker_type))

        if debug:
            logger.debug('Debug logging enabled')

        self.context = zmq.Context()
        self.poller = zmq.Poller()
        self.identity = worker_id.encode()

        self.task_socket = self.context.socket(zmq.DEALER)
        self.task_socket.setsockopt(zmq.IDENTITY, self.identity)

        logger.info('Trying to connect to : tcp://{}:{}'.format(
            self.address, self.port))
        self.task_socket.connect('tcp://{}:{}'.format(self.address, self.port))
        self.poller.register(self.task_socket, zmq.POLLIN)
Example #14
    def __init__(self, dlh_authorizer=None, search_client=None, http_timeout=None,
                 force_login=False, fx_authorizer=None, **kwargs):
        """Initialize the client

        Args:
            dlh_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with DLHub.
                If ``None``, will be created.
            search_client (:class:`SearchClient <globus_sdk.SearchClient>`):
                An authenticated SearchClient to communicate with Globus Search.
                If ``None``, will be created.
            http_timeout (int): Timeout for any call to service in seconds. (default is no timeout)
            force_login (bool): Whether to force a login to get new credentials.
                A login will always occur if ``dlh_authorizer`` or ``search_client``
                are not provided.
            no_local_server (bool): Disable spinning up a local server to automatically
                copy-paste the auth code. THIS IS REQUIRED if you are on a remote server.
                When used locally with no_local_server=False, the domain is localhost with
                a randomly chosen open port number.
                **Default**: ``True``.
            fx_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with funcX.
                If ``None``, will be created.
            no_browser (bool): Do not automatically open the browser for the Globus Auth URL.
                Display the URL instead and let the user navigate to that location manually.
                **Default**: ``True``.
        Keyword arguments are the same as for BaseClient.
        """
        if force_login or not dlh_authorizer or not search_client or not fx_authorizer:

            fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
            auth_res = login(services=["search", "dlhub",
                                       fx_scope],
                             app_name="DLHub_Client",
                             client_id=CLIENT_ID, clear_old_tokens=force_login,
                             token_dir=_token_dir, no_local_server=kwargs.get("no_local_server", True),
                             no_browser=kwargs.get("no_browser", True))
            dlh_authorizer = auth_res["dlhub"]
            fx_authorizer = auth_res[fx_scope]
            self._search_client = auth_res["search"]
            self._fx_client = FuncXClient(force_login=True, fx_authorizer=fx_authorizer,
                                          funcx_service_address='https://funcx.org/api/v1')

        # funcX endpoint to use
        self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
        # self.fx_endpoint = '2c92a06a-015d-4bfa-924c-b3d0c36bdad7'
        self.fx_serializer = FuncXSerializer()
        self.fx_cache = {}
        super(DLHubClient, self).__init__("DLHub", environment='dlhub', authorizer=dlh_authorizer,
                                          http_timeout=http_timeout, base_url=DLHUB_SERVICE_ADDRESS,
                                          **kwargs)
Example #15
def server(port=0, host="", debug=False, datasize=102400):

    try:
        from funcx.serialize import FuncXSerializer

        fxs = FuncXSerializer(use_offprocess_checker=False)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind((host, port))
            bound_port = s.getsockname()[1]
            print(f"BINDING TO:{bound_port}", flush=True)
            s.listen(1)
            conn, addr = s.accept()  # we only expect one incoming connection here.
            with conn:
                while True:

                    b_msg = conn.recv(datasize)
                    if not b_msg:
                        print("Exiting")
                        return

                    msg = pickle.loads(b_msg)

                    if msg == "PING":
                        ret_value = ("PONG", None)
                    else:
                        try:
                            method = fxs.deserialize(msg)  # noqa
                            del method
                        except Exception as e:
                            ret_value = ("DESERIALIZE_FAIL", str(e))

                        else:
                            ret_value = ("SUCCESS", None)

                    ret_buf = pickle.dumps(ret_value)
                    conn.sendall(ret_buf)
    except Exception as e:
        print(f"OFF_PROCESS_CHECKER FAILURE, Exception:{e}")
        sys.exit()
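The off-process checker above accepts one connection, reads a pickled message per recv, and answers "PING" with ("PONG", None). A minimal client sketch that exercises that path; the port is whatever the server printed in its "BINDING TO:<port>" line, and 54321 here is only a placeholder.

import pickle
import socket

# Connect to the checker, send a pickled PING, and read the pickled reply.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.connect(("127.0.0.1", 54321))   # placeholder port
    s.sendall(pickle.dumps("PING"))
    reply = pickle.loads(s.recv(102400))
    print(reply)  # expected: ("PONG", None)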
Example #16
class FuncXFuture(Future):
    client = FuncXClient()
    serializer = FuncXSerializer()

    def __init__(self, task_id, poll_period=1):
        super().__init__()
        self.task_id = task_id
        self.poll_period = poll_period
        self.__result = None
        self.submitted = time.time()

    def done(self):
        if self.__result is not None:
            return True
        try:
            data = FuncXFuture.client.get_task_status(self.task_id)
        except Exception:
            return False
        if 'status' in data and data['status'] == 'PENDING':
            time.sleep(
                self.poll_period)  # needed to not overwhelm the FuncX server
            return False
        elif 'result' in data:
            self.__result = FuncXFuture.serializer.deserialize(data['result'])
            self.returned = time.time()
            # FIXME AW benchmarking
            self.connected_managers = os.environ.get('connected_managers', -1)

            return True
        elif 'exception' in data:
            e = FuncXFuture.serializer.deserialize(data['exception'])
            e.reraise()
        else:
            raise NotImplementedError(
                'task {} is neither pending nor finished: {}'.format(
                    self.task_id, str(data)))

    def result(self, timeout=None):
        if self.__result is not None:
            return self.__result
        while True:
            if self.done():
                break
            else:
                time.sleep(self.poll_period)
                if timeout is not None:
                    timeout -= self.poll_period
                    if timeout < 0:
                        raise TimeoutError

        return self.__result
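A usage sketch for FuncXFuture; the task id is a placeholder, and in practice it would come from a FuncXClient submission call.

# task_id is a placeholder; a real id comes back from FuncXClient.run(...)
future = FuncXFuture("00000000-0000-0000-0000-000000000000", poll_period=2)

try:
    value = future.result(timeout=60)
    print("Result:", value)
except TimeoutError:
    print("Task did not finish within 60 seconds")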
Example #17
def dont_run_yet(endpoint_id=None, tasks=10, duration=1, hostname=None):
    tasks_rq = EndpointQueue(f"task_{endpoint_id}", hostname)
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    tasks_rq.connect()
    start = time.time()
    task_ids = {}
    for i in range(tasks):
        task_id = str(uuid.uuid4())
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({"duration": duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "RAW"
        task = Task(tasks_rq.redis_client,
                    task_id,
                    container_id,
                    serializer="",
                    payload=payload)
        tasks_rq.enqueue(task)
        task_ids[i] = task_id

    d1 = time.time() - start
    print(f"Time to launch {tasks} tasks: {d1:8.3f} s")

    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        task_id = task_ids[i]
        task = Task.from_id(tasks_rq.redis_client, task_id)
        # TODO: wait for task result...
        time.sleep(2)
        print(f"Result: {task.result}")
        # res = results_rq.get('result', timeout=300)
        # print("Result : ", res)

    delta = time.time() - start
    print(f"Time to complete {tasks} tasks: {delta:8.3f} s")
    print(f"Throughput : {tasks / delta:8.3f} Tasks/s")
    return delta
Example #18
def test(endpoint_id=None, tasks=10, duration=1, hostname=None, port=None):
    tasks_rq = RedisQueue(f'task_{endpoint_id}', hostname)
    results_rq = RedisQueue('results', hostname)
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    tasks_rq.connect()
    results_rq.connect()

    while True:
        try:
            _ = results_rq.get(timeout=1)
        except Exception:
            print("No more results left")
            break

    start = time.time()
    for i in range(tasks):
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({'duration': duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "odd" if i % 2 else "even"
        tasks_rq.put(f"0{i};{container_id}", payload)

    d1 = time.time() - start
    print("Time to launch {} tasks: {:8.3f} s".format(tasks, d1))

    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        _ = results_rq.get(timeout=300)
        # print("Result : ", res)

    delta = time.time() - start
    print("Time to complete {} tasks: {:8.3f} s".format(tasks, delta))
    print("Throughput : {:8.3f} Tasks/s".format(tasks / delta))
    return delta
Example #19
    def __init__(
            self,
            task_q_url="tcp://127.0.0.1:50097",
            result_q_url="tcp://127.0.0.1:50098",
            max_queue_size=10,
            cores_per_worker=1,
            max_workers=float('inf'),
            uid=None,
            heartbeat_threshold=120,
            heartbeat_period=30,
            logdir=None,
            debug=False,
            block_id=None,
            internal_worker_port_range=(50000, 60000),
            mode="singularity_reuse",
            container_image=None,
            # TODO : This should be 10ms
            poll_period=100):
        """
        Parameters
        ----------
        task_q_url : str
             Url of the interchange task queue to which the manager connects

        result_q_url : str
             Url of the interchange result queue to which the manager connects

        uid : str
             string unique identifier

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        max_workers : int
             caps the maximum number of workers that can be launched.
             default: infinity

        heartbeat_threshold : int
             Number of seconds since the last message from the interchange after which the
             interchange is assumed to be lost and the manager initiates shutdown. Default: 120

        heartbeat_period : int
             Number of seconds after which a heartbeat message is sent to the interchange

        internal_worker_port_range : tuple(int, int)
             Port range from which the port(s) for the workers to connect to the manager is picked.
             Default: (50000,60000)

        mode : str
             Pick between 3 supported modes for the worker:
              1. no_container : Worker launched without containers
              2. singularity_reuse : Worker launched inside a singularity container that will be reused
              3. singularity_single_use : Each worker and task runs inside a new container instance.

        container_image : str
             Path or identifier for the container to be used. Default: None

        poll_period : int
             Timeout period used by the manager, in milliseconds. Default: 100 (see the TODO above; intended to be 10)
        """

        logger.info("Manager started")

        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        # Linger is set to 0, so that the manager can exit even when there might be
        # messages in the pipe
        self.task_incoming.setsockopt(zmq.LINGER, 0)
        self.task_incoming.connect(task_q_url)

        self.logdir = logdir
        self.debug = debug
        self.block_id = block_id
        self.result_outgoing = self.context.socket(zmq.DEALER)
        self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        self.result_outgoing.setsockopt(zmq.LINGER, 0)
        self.result_outgoing.connect(result_q_url)
        logger.info("Manager connected")

        self.uid = uid

        self.mode = mode
        self.container_image = container_image
        self.cores_on_node = multiprocessing.cpu_count()
        self.max_workers = max_workers
        self.cores_per_workers = cores_per_worker
        self.available_mem_on_node = round(
            psutil.virtual_memory().available / (2**30), 1)
        self.worker_count = min(
            max_workers, math.floor(self.cores_on_node / cores_per_worker))
        self.worker_map = WorkerMap(self.worker_count)

        self.internal_worker_port_range = internal_worker_port_range

        self.funcx_task_socket = self.context.socket(zmq.ROUTER)
        self.funcx_task_socket.set_hwm(0)
        self.address = '127.0.0.1'
        self.worker_port = self.funcx_task_socket.bind_to_random_port(
            "tcp://*",
            min_port=self.internal_worker_port_range[0],
            max_port=self.internal_worker_port_range[1])

        logger.info(
            "Manager listening on {} port for incoming worker connections".
            format(self.worker_port))

        self.task_queues = {'RAW': queue.Queue()}

        self.pending_result_queue = multiprocessing.Queue()

        self.max_queue_size = max_queue_size + self.worker_count
        self.tasks_per_round = 1

        self.heartbeat_period = heartbeat_period
        self.heartbeat_threshold = heartbeat_threshold
        self.poll_period = poll_period
        self.serializer = FuncXSerializer()
        self.next_worker_q = []  # FIFO queue for spinning up workers.
Example #20
import json
import sys
import argparse
import time
import funcx
from funcx import FuncXClient
from funcx.serialize import FuncXSerializer
fxs = FuncXSerializer()

# funcx.set_stream_logger()


def double(x):
    return x * 2


def test(fxc, ep_id, task_count=10):

    fn_uuid = fxc.register_function(double, description="Yadu double")
    print("FN_UUID : ", fn_uuid)

    start = time.time()
    task_ids = fxc.map_run(list(range(task_count)),
                           endpoint_id=ep_id,
                           function_id=fn_uuid)
    delta = time.time() - start
    print("Time to launch {} tasks: {:8.3f} s".format(task_count, delta))
    print("Got {} tasks_ids ".format(len(task_ids)))

    for i in range(3):
        x = fxc.get_batch_status(task_ids)
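The driver above stops after three status calls. A rough continuation that polls get_batch_status until no task is still pending; the 'pending' key mirrors the status dicts built by update_table in Example #24, and the poll interval is arbitrary.

import time

def wait_for_batch(fxc, task_ids, poll_interval=2):
    # Sketch only: poll until every task in the batch reports a final status.
    pending = set(task_ids)
    statuses = {}
    while pending:
        for tid, status in fxc.get_batch_status(list(pending)).items():
            if not status.get('pending', True):
                statuses[tid] = status
                pending.discard(tid)
        if pending:
            time.sleep(poll_interval)
    return statuses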
Example #21
    def __init__(self,
                 task_q,
                 result_q,
                 executor,
                 endpoint_id,
                 heartbeat_threshold=60,
                 endpoint_addr=None,
                 redis_address=None,
                 logdir="forwarder_logs",
                 logging_level=logging.INFO,
                 max_heartbeats_missed=2):
        """
        Parameters
        ----------
        task_q : A queue object
        Any queue object that has get primitives. This must be a thread-safe queue.

        result_q : A queue object
        Any queue object that has put primitives. This must be a thread-safe queue.

        executor: Executor object
        Executor to which tasks are to be forwarded

        endpoint_id: str
        Usually a uuid4 as string that identifies the executor

        endpoint_addr: str
        Endpoint ip address as a string

        heartbeat_threshold : int
        Heartbeat threshold in seconds

        logdir: str
        Path to logdir

        logging_level : int
        Logging level as defined in the logging module. Default: logging.INFO (20)

        max_heartbeats_missed : int
        The maximum heartbeats missed before the forwarder terminates

        """
        super().__init__()
        self.logdir = logdir
        os.makedirs(self.logdir, exist_ok=True)

        global logger
        logger = logging.getLogger(endpoint_id)

        if len(logger.handlers) == 0:
            logger = set_file_logger(os.path.join(
                self.logdir, "forwarder.{}.log".format(endpoint_id)),
                                     name=endpoint_id,
                                     level=logging_level)

        logger.info(
            "Initializing forwarder for endpoint:{}".format(endpoint_id))
        logger.info("Log level set to {}".format(loglevels[logging_level]))

        self.endpoint_addr = endpoint_addr
        self.task_q = task_q
        self.result_q = result_q
        self.heartbeat_threshold = heartbeat_threshold
        self.executor = executor
        self.endpoint_id = endpoint_id
        self.endpoint_addr = endpoint_addr
        self.redis_address = redis_address
        self.internal_q = Queue()
        self.client_ports = None
        self.fx_serializer = FuncXSerializer()
        self.kill_event = threading.Event()
        self.max_heartbeats_missed = max_heartbeats_missed
Example #22
except Exception as e:
    print(e)
print("Started the endpoint {}".format(endpoint_name))
print("Wating 10 seconds for the endpoint to start")
time.sleep(10)

# Connect to the task and result redis queue
endpoint_id = args.endpoint_id
tasks_rq = RedisQueue(f'task_{endpoint_id}', args.redis_hostname)
results_rq = RedisQueue('results', args.redis_hostname)
tasks_rq.connect()
results_rq.connect()
print("Redis queue connected")

# Create an instance of funcx serializer and serialize the function
fxs = FuncXSerializer()
ser_code = fxs.serialize(sleep)
fn_code = fxs.pack_buffers([ser_code])
print("Code serialized")

# Define the test function
def test(tasks=10, durs=[5, 10, 20], timeout=None):
    # Make sure there is no previous result left
    while True:
        try:
            _ = results_rq.get(timeout=1)
        except Exception:
            print("No more results left")
            break
    start_submit = time.time()
    time_table = {}
Example #23
    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://api.funcx.org/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer : :class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://api.funcx.org/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.func_table = {}
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename))

        # TODO: if fx_authorizer is given, we still need to get an authorizer for Search
        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
        search_scope = "urn:globus:auth:scope:search.api.globus.org:all"
        scopes = [fx_scope, search_scope, "openid"]

        search_authorizer = None

        if not fx_authorizer:
            self.native_client.login(
                requested_scopes=scopes,
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=scopes)
            fx_authorizer = all_authorizers[fx_scope]
            search_authorizer = all_authorizers[search_scope]
            openid_authorizer = all_authorizers["openid"]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

        authclient = AuthClient(authorizer=openid_authorizer)
        user_info = authclient.oauth2_userinfo()
        self.searcher = SearchHelper(authorizer=search_authorizer,
                                     owner_uuid=user_info['sub'])
        self.funcx_service_address = funcx_service_address
Example #24
class FuncXClient(throttling.ThrottledBaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = 'funcx_sdk_tokens.json'
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://api.funcx.org/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer : :class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://api.funcx.org/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.func_table = {}
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename))

        # TODO: if fx_authorizer is given, we still need to get an authorizer for Search
        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
        search_scope = "urn:globus:auth:scope:search.api.globus.org:all"
        scopes = [fx_scope, search_scope, "openid"]

        search_authorizer = None

        if not fx_authorizer:
            self.native_client.login(
                requested_scopes=scopes,
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=scopes)
            fx_authorizer = all_authorizers[fx_scope]
            search_authorizer = all_authorizers[search_scope]
            openid_authorizer = all_authorizers["openid"]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

        authclient = AuthClient(authorizer=openid_authorizer)
        user_info = authclient.oauth2_userinfo()
        self.searcher = SearchHelper(authorizer=search_authorizer,
                                     owner_uuid=user_info['sub'])
        self.funcx_service_address = funcx_service_address

    def version_check(self):
        """Check this client version meets the service's minimum supported version.
        """
        resp = self.get("version", params={"service": "all"})
        versions = resp.data
        if "min_ep_version" not in versions:
            raise VersionMismatch(
                "Failed to retrieve version information from funcX service.")

        min_ep_version = versions['min_ep_version']

        if ENDPOINT_VERSION is None:
            raise VersionMismatch(
                "You do not have the funcx endpoint installed.  You can use 'pip install funcx-endpoint'."
            )
        if ENDPOINT_VERSION < min_ep_version:
            raise VersionMismatch(
                f"Your version={ENDPOINT_VERSION} is lower than the "
                f"minimum version for an endpoint: {min_ep_version}.  Please update."
            )

    def logout(self):
        """Remove credentials from your local system
        """
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """ Parses the return message from the service and updates the internal func_tables

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        status = {'pending': True}

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
                completion_t = r_dict['completion_t']
            except Exception:
                raise SerializationError("Result Object Deserialization")
            else:
                status.update({
                    'pending': False,
                    'result': r_obj,
                    'completion_t': completion_t
                })
                self.func_table[task_id] = status

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                completion_t = r_dict['completion_t']
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise SerializationError(
                    "Task's exception object deserialization")
            else:
                status.update({
                    'pending': False,
                    'exception': r_exception,
                    'completion_t': completion_t
                })
                self.func_table[task_id] = status
        return status

    def get_task(self, task_id):
        """Get a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Task block containing "status" key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.get("tasks/{task_id}".format(task_id=task_id))
        logger.debug("Response string : {}".format(r))
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        task = self.get_task(task_id)
        if task['pending'] is True:
            raise Exception("Task pending")
        else:
            if 'result' in task:
                return task['result']
            else:
                logger.warning("We have an exception : {}".format(
                    task['exception']))
                task['exception'].reraise()

    def get_batch_status(self, task_id_list):
        """ Request status for a batch of task_ids
        """
        assert isinstance(task_id_list,
                          list), "get_batch_status expects a list of task ids"

        pending_task_ids = [
            t for t in task_id_list if t not in self.func_table
        ]

        results = {}

        if pending_task_ids:
            payload = {'task_ids': pending_task_ids}
            r = self.post("/batch_status", json_body=payload)
            logger.debug("Response string : {}".format(r))

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r['results'][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug(
                        "Task {} info was not available in the batch status".format(task_id))
                except Exception:
                    logger.exception(
                        "Failure while unpacking results from get_batch_status")
            else:
                results[task_id] = self.func_table[task_id]

        return results

    def get_batch_result(self, task_id_list):
        """ Request results for a batch of task_ids
        """
        pass

    def run(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        batch = self.create_batch()
        batch.add(*args,
                  endpoint_id=endpoint_id,
                  function_id=function_id,
                  **kwargs)
        r = self.batch_run(batch)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """

        return r[0]

    def create_batch(self):
        """
        Create a Batch instance to handle batch submission in funcX

        Parameters
        ----------

        Returns
        -------
        Batch instance
            Status block containing "status" key.
        """
        batch = Batch()
        return batch

    def batch_run(self, batch):
        """Initiate a batch of tasks to funcX

        Parameters
        ----------
        batch: a Batch object

        Returns
        -------
        task_ids : a list of UUID strings that identify the tasks
        """
        servable_path = 'submit'
        assert isinstance(batch, Batch), "Requires a Batch object as input"
        assert len(batch.tasks) > 0, "Requires a non-empty batch"

        data = batch.prepare()

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)
        if r.get("status", "Failure") == "Failure":
            raise MalformedResponse("FuncX Request failed: {}".format(
                r.get("reason", "Unknown")))
        return r['task_uuids']

    def map_run(self,
                *args,
                endpoint_id=None,
                function_id=None,
                asynchronous=False,
                **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit_batch'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg, ))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            'endpoints': [endpoint_id],
            'func': function_id,
            'payload': batch_payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if r.get("status", "Failure") == "Failure":
            raise MalformedResponse("FuncX Request failed: {}".format(
                r.get("reason", "Unknown")))
        return r['task_uuids']

    def register_endpoint(self,
                          name,
                          endpoint_uuid,
                          metadata=None,
                          endpoint_version=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        metadata : dict
            endpoint metadata, see default_config example
        endpoint_version: str
            Version string to be passed to the webService as a compatibility check

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        self.version_check()

        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "version": endpoint_version
        }
        if metadata:
            data['meta'] = metadata

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        tuple
            The endpoint UUID and the list of containers to launch
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['container']

    def get_endpoint_status(self, endpoint_uuid):
        """Get the status reports for an endpoint.

        Parameters
        ----------
        endpoint_uuid : str
            UUID of the endpoint in question

        Returns
        -------
        dict
            The details of the endpoint's stats
        """
        stats_path = f'endpoints/{endpoint_uuid}/status'

        r = self.get(stats_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None,
                          public=False,
                          group=None,
                          searchable=True):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution
        function_name : str
            The entry point (function name) of the function. Default: None
        container_uuid : str
            Container UUID from registration with funcX
        description : str
            Description of the file
        public : bool
            Whether or not the function is publicly accessible. Default = False
        group : str
            A globus group uuid to share this function with
        searchable : bool
            If true, the function will be indexed into globus search with the appropriate permissions

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        source_code = ""
        try:
            source_code = getsource(function)
        except OSError:
            logger.error(
                "Failed to find source code during function registration.")

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "function_source": source_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description,
            "public": public,
            "group": group,
            "searchable": searchable
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        func_uuid = r.data['function_uuid']

        # Return the result
        return func_uuid

    def update_function(self, func_uuid, function):
        pass

    def search_function(self, q, offset=0, limit=10, advanced=False):
        """Search for function via the funcX service

        Parameters
        ----------
        q : str
            free-form query string
        offset : int
            offset into total results
        limit : int
            max number of results to return
        advanced : bool
            allows elastic-search like syntax in query string

        Returns
        -------
        FunctionSearchResults
        """
        return self.searcher.search_function(q,
                                             offset=offset,
                                             limit=limit,
                                             advanced=advanced)

    def search_endpoint(self, q, scope='all', owner_id=None):
        """

        Parameters
        ----------
        q
        scope : str
            Can be one of {'all', 'my-endpoints', 'shared-with-me'}
        owner_id
            should be urn like f"urn:globus:auth:identity:{owner_uuid}"

        Returns
        -------

        """
        return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id)

    def register_container(self,
                           location,
                           container_type,
                           name='',
                           description=''):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker). Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        container_path = 'containers'

        payload = {
            'name': name,
            'location': location,
            'description': description,
            'type': container_type
        }

        r = self.post(container_path, json_body=payload)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['container_id']

    def add_to_whitelist(self, endpoint_id, function_ids):
        """Adds the function to the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function IDs to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        req_path = f'endpoints/{endpoint_id}/whitelist'

        if not isinstance(function_ids, list):
            function_ids = [function_ids]

        payload = {'func': function_ids}

        r = self.post(req_path, json_body=payload)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r

    def get_whitelist(self, endpoint_id):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint

        Returns
        -------
        json
            The response of the request
        """
        req_path = f'endpoints/{endpoint_id}/whitelist'

        r = self.get(req_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r

    def delete_from_whitelist(self, endpoint_id, function_ids):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function IDs to be removed from the endpoint's whitelist

        Returns
        -------
        json
            The response of the request
        """
        if not isinstance(function_ids, list):
            function_ids = [function_ids]
        res = []
        for fid in function_ids:
            req_path = f'endpoints/{endpoint_id}/whitelist/{fid}'

            r = self.delete(req_path)
            if r.http_status != 200:
                raise HTTPError(r)
            res.append(r)

        # Return the result
        return res
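A hedged usage sketch of the search, container, and whitelist helpers above. It assumes `fxc` is an already-authenticated instance of the client class these methods belong to, that the two UUIDs are placeholders supplied by the caller, and that 'docker' is an accepted container_type string (the docstring only lists Singularity, Shifter, and Docker informally).

def demo_admin_calls(fxc, endpoint_id, func_uuid):
    # Hypothetical helper for illustration only; nothing here is the library's
    # official workflow, it just exercises the methods shown in this example.
    print(fxc.search_function("double", limit=5))
    print(fxc.search_endpoint("gpu", scope="my-endpoints"))

    container_id = fxc.register_container(
        "docker://python:3.8", "docker",
        name="py38", description="stock Python 3.8 image")
    print("registered container", container_id)

    fxc.add_to_whitelist(endpoint_id, [func_uuid])
    print(fxc.get_whitelist(endpoint_id))
    fxc.delete_from_whitelist(endpoint_id, [func_uuid])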
Example #25
0
def _get_packed_code(
    func: t.Callable, serializer: t.Optional[FuncXSerializer] = None
) -> str:
    serializer = serializer if serializer else FuncXSerializer()
    return serializer.pack_buffers([serializer.serialize(func)])
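A minimal sketch of calling the helper above. The packed string it returns has the same shape as the `function_code` payload built during function registration elsewhere in this listing; only the funcx serializer import already used throughout these examples is assumed.

from funcx.serialize import FuncXSerializer


def double(x):
    return x * 2


packed = _get_packed_code(double, serializer=FuncXSerializer())
print(type(packed), len(packed))   # a string buffer ready to ship to the service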
Example #26
0
    def __init__(self,
                 config,
                 client_address="127.0.0.1",
                 interchange_address="127.0.0.1",
                 client_ports=(50055, 50056, 50057),
                 worker_ports=None,
                 worker_port_range=(54000, 55000),
                 cores_per_worker=1.0,
                 worker_debug=False,
                 launch_cmd=None,
                 heartbeat_threshold=60,
                 logdir=".",
                 logging_level=logging.INFO,
                 poll_period=10,
                 endpoint_id=None,
                 suppress_failure=False,
                 max_heartbeats_missed=2
                 ):
        """
        Parameters
        ----------
        config : funcx.Config object
             Funcx config object that describes how compute should be provisioned

        client_address : str
             The ip address at which the parsl client can be reached. Default: "127.0.0.1"

        interchange_address : str
             The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1"

        client_ports : triple(int, int, int)
             The ports at which the client can be reached

        launch_cmd : str
             TODO : update

        worker_ports : tuple(int, int)
             The specific two ports at which workers will connect to the Interchange. Default: None

        worker_port_range : tuple(int, int)
             The interchange picks ports at random from the range which will be used by workers.
             This is overridden when the worker_ports option is set. Default: (54000, 55000)

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        worker_debug : Bool
             Enables worker debug logging.

        heartbeat_threshold : int
             Number of seconds since the last heartbeat after which a worker is considered lost.

        logdir : str
             Parsl log directory paths. Logs and temp files go here. Default: '.'

        logging_level : int
             Logging level as defined in the logging module. Default: logging.INFO (20)

        endpoint_id : str
             Identity string that identifies the endpoint to the broker

        poll_period : int
             The main thread polling period, in milliseconds. Default: 10ms

        suppress_failure : Bool
             When set to True, the interchange will attempt to suppress failures. Default: False

        max_heartbeats_missed : int
             Number of heartbeats missed before setting kill_event

        """
        self.logdir = logdir
        try:
            os.makedirs(self.logdir)
        except FileExistsError:
            pass

        start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level)
        logger.info("logger location {}".format(logger.handlers))
        logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id))
        self.config = config
        logger.info("Got config : {}".format(config))

        self.strategy = self.config.strategy
        self.client_address = client_address
        self.interchange_address = interchange_address
        self.suppress_failure = suppress_failure
        self.poll_period = poll_period

        self.serializer = FuncXSerializer()
        logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
            client_address, client_ports[0], client_ports[1], client_ports[2]))
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.set_hwm(0)
        self.task_incoming.RCVTIMEO = 10  # in milliseconds
        logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0]))
        self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))

        self.results_outgoing = self.context.socket(zmq.DEALER)
        self.results_outgoing.set_hwm(0)
        logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1]))
        self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))

        self.command_channel = self.context.socket(zmq.DEALER)
        self.command_channel.RCVTIMEO = 1000  # in milliseconds
        # self.command_channel.set_hwm(0)
        logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2]))
        self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
        logger.info("Connected to client")

        self.pending_task_queue = {}
        self.containers = {}
        self.total_pending_task_count = 0
        self.fxs = FuncXClient()

        logger.info("Interchange address is {}".format(self.interchange_address))
        self.worker_ports = worker_ports
        self.worker_port_range = worker_port_range

        self.task_outgoing = self.context.socket(zmq.ROUTER)
        self.task_outgoing.set_hwm(0)
        self.results_incoming = self.context.socket(zmq.ROUTER)
        self.results_incoming.set_hwm(0)

        # initialize the last heartbeat time to start the loop
        self.last_heartbeat = time.time()
        self.max_heartbeats_missed = max_heartbeats_missed

        self.endpoint_id = endpoint_id
        if self.worker_ports:
            self.worker_task_port = self.worker_ports[0]
            self.worker_result_port = self.worker_ports[1]

            self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port))
            self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port))

        else:
            self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*',
                                                                           min_port=worker_port_range[0],
                                                                           max_port=worker_port_range[1], max_tries=100)
            self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*',
                                                                                min_port=worker_port_range[0],
                                                                                max_port=worker_port_range[1], max_tries=100)

        logger.info("Bound to ports {},{} for incoming worker connections".format(
            self.worker_task_port, self.worker_result_port))

        self._ready_manager_queue = {}

        self.heartbeat_threshold = heartbeat_threshold
        self.blocks = {}  # type: Dict[str, str]
        self.block_id_map = {}
        self.launch_cmd = launch_cmd
        self.last_core_hr_counter = 0
        if not launch_cmd:
            self.launch_cmd = ("funcx-manager {debug} {max_workers} "
                               "-c {cores_per_worker} "
                               "--poll {poll_period} "
                               "--task_url={task_url} "
                               "--result_url={result_url} "
                               "--logdir={logdir} "
                               "--block_id={{block_id}} "
                               "--hb_period={heartbeat_period} "
                               "--hb_threshold={heartbeat_threshold} "
                               "--worker_mode={worker_mode} "
                               "--scheduler_mode={scheduler_mode} "
                               "--worker_type={{worker_type}} ")

        self.current_platform = {'parsl_v': PARSL_VERSION,
                                 'python_v': "{}.{}.{}".format(sys.version_info.major,
                                                               sys.version_info.minor,
                                                               sys.version_info.micro),
                                 'os': platform.system(),
                                 'hname': platform.node(),
                                 'dir': os.getcwd()}

        logger.info("Platform info: {}".format(self.current_platform))
        self._block_counter = 0
        try:
            self.load_config()
        except Exception:
            logger.exception("Caught exception")
            raise
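The launch command assembled above is a two-stage template: single-brace fields such as {task_url} are filled in later by load_config() (shown in the next example), while double-brace fields such as {{block_id}} survive that first .format() call and are only filled in at scale-out time. A small self-contained sketch of the same escaping pattern:

template = "funcx-manager --task_url={task_url} --block_id={{block_id}}"

stage_one = template.format(task_url="tcp://127.0.0.1:54000")
assert stage_one == "funcx-manager --task_url=tcp://127.0.0.1:54000 --block_id={block_id}"

stage_two = stage_one.format(block_id="1")
assert stage_two == "funcx-manager --task_url=tcp://127.0.0.1:54000 --block_id=1"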
Example #27
0
class Interchange(object):
    """ Interchange is a task orchestrator for distributed systems.

    1. Asynchronously queue large volume of tasks (>100K)
    2. Allow for workers to join and leave the union
    3. Detect workers that have failed using heartbeats
    4. Service single and batch requests from workers
    5. Be aware of workers' resource capacity,
       e.g. schedule only jobs that fit into the available walltime.

    TODO: We most likely need a PUB channel to send out global commands, like shutdown
    """
    def __init__(self,
                 config,
                 client_address="127.0.0.1",
                 interchange_address="127.0.0.1",
                 client_ports=(50055, 50056, 50057),
                 worker_ports=None,
                 worker_port_range=(54000, 55000),
                 cores_per_worker=1.0,
                 worker_debug=False,
                 launch_cmd=None,
                 heartbeat_threshold=60,
                 logdir=".",
                 logging_level=logging.INFO,
                 poll_period=10,
                 endpoint_id=None,
                 suppress_failure=False,
                 max_heartbeats_missed=2
                 ):
        """
        Parameters
        ----------
        config : funcx.Config object
             Funcx config object that describes how compute should be provisioned

        client_address : str
             The ip address at which the parsl client can be reached. Default: "127.0.0.1"

        interchange_address : str
             The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1"

        client_ports : triple(int, int, int)
             The ports at which the client can be reached

        launch_cmd : str
             TODO : update

        worker_ports : tuple(int, int)
             The specific two ports at which workers will connect to the Interchange. Default: None

        worker_port_range : tuple(int, int)
             The interchange picks ports at random from the range which will be used by workers.
             This is overridden when the worker_ports option is set. Default: (54000, 55000)

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        worker_debug : Bool
             Enables worker debug logging.

        heartbeat_threshold : int
             Number of seconds since the last heartbeat after which a worker is considered lost.

        logdir : str
             Parsl log directory paths. Logs and temp files go here. Default: '.'

        logging_level : int
             Logging level as defined in the logging module. Default: logging.INFO (20)

        endpoint_id : str
             Identity string that identifies the endpoint to the broker

        poll_period : int
             The main thread polling period, in milliseconds. Default: 10ms

        suppress_failure : Bool
             When set to True, the interchange will attempt to suppress failures. Default: False

        max_heartbeats_missed : int
             Number of heartbeats missed before setting kill_event

        """
        self.logdir = logdir
        try:
            os.makedirs(self.logdir)
        except FileExistsError:
            pass

        start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level)
        logger.info("logger location {}".format(logger.handlers))
        logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id))
        self.config = config
        logger.info("Got config : {}".format(config))

        self.strategy = self.config.strategy
        self.client_address = client_address
        self.interchange_address = interchange_address
        self.suppress_failure = suppress_failure
        self.poll_period = poll_period

        self.serializer = FuncXSerializer()
        logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
            client_address, client_ports[0], client_ports[1], client_ports[2]))
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.set_hwm(0)
        self.task_incoming.RCVTIMEO = 10  # in milliseconds
        logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0]))
        self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))

        self.results_outgoing = self.context.socket(zmq.DEALER)
        self.results_outgoing.set_hwm(0)
        logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1]))
        self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))

        self.command_channel = self.context.socket(zmq.DEALER)
        self.command_channel.RCVTIMEO = 1000  # in milliseconds
        # self.command_channel.set_hwm(0)
        logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2]))
        self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
        logger.info("Connected to client")

        self.pending_task_queue = {}
        self.containers = {}
        self.total_pending_task_count = 0
        self.fxs = FuncXClient()

        logger.info("Interchange address is {}".format(self.interchange_address))
        self.worker_ports = worker_ports
        self.worker_port_range = worker_port_range

        self.task_outgoing = self.context.socket(zmq.ROUTER)
        self.task_outgoing.set_hwm(0)
        self.results_incoming = self.context.socket(zmq.ROUTER)
        self.results_incoming.set_hwm(0)

        # initialize the last heartbeat time to start the loop
        self.last_heartbeat = time.time()
        self.max_heartbeats_missed = max_heartbeats_missed

        self.endpoint_id = endpoint_id
        if self.worker_ports:
            self.worker_task_port = self.worker_ports[0]
            self.worker_result_port = self.worker_ports[1]

            self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port))
            self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port))

        else:
            self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*',
                                                                           min_port=worker_port_range[0],
                                                                           max_port=worker_port_range[1], max_tries=100)
            self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*',
                                                                                min_port=worker_port_range[0],
                                                                                max_port=worker_port_range[1], max_tries=100)

        logger.info("Bound to ports {},{} for incoming worker connections".format(
            self.worker_task_port, self.worker_result_port))

        self._ready_manager_queue = {}

        self.heartbeat_threshold = heartbeat_threshold
        self.blocks = {}  # type: Dict[str, str]
        self.block_id_map = {}
        self.launch_cmd = launch_cmd
        self.last_core_hr_counter = 0
        if not launch_cmd:
            self.launch_cmd = ("funcx-manager {debug} {max_workers} "
                               "-c {cores_per_worker} "
                               "--poll {poll_period} "
                               "--task_url={task_url} "
                               "--result_url={result_url} "
                               "--logdir={logdir} "
                               "--block_id={{block_id}} "
                               "--hb_period={heartbeat_period} "
                               "--hb_threshold={heartbeat_threshold} "
                               "--worker_mode={worker_mode} "
                               "--scheduler_mode={scheduler_mode} "
                               "--worker_type={{worker_type}} ")

        self.current_platform = {'parsl_v': PARSL_VERSION,
                                 'python_v': "{}.{}.{}".format(sys.version_info.major,
                                                               sys.version_info.minor,
                                                               sys.version_info.micro),
                                 'os': platform.system(),
                                 'hname': platform.node(),
                                 'dir': os.getcwd()}

        logger.info("Platform info: {}".format(self.current_platform))
        self._block_counter = 0
        try:
            self.load_config()
        except Exception:
            logger.exception("Caught exception")
            raise


    def load_config(self):
        """ Load the config
        """
        logger.info("Loading endpoint local config")
        working_dir = self.config.working_dir
        if self.config.working_dir is None:
            working_dir = "{}/{}".format(self.logdir, "worker_logs")
        logger.info("Setting working_dir: {}".format(working_dir))

        self.config.provider.script_dir = working_dir
        if hasattr(self.config.provider, 'channel'):
            self.config.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts')
            self.config.provider.channel.makedirs(self.config.provider.channel.script_dir, exist_ok=True)
            os.makedirs(self.config.provider.script_dir, exist_ok=True)

        debug_opts = "--debug" if self.config.worker_debug else ""
        max_workers = "" if self.config.max_workers_per_node == float('inf') \
                      else "--max_workers={}".format(self.config.max_workers_per_node)

        worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}"
        worker_result_url = f"tcp://{self.interchange_address}:{self.worker_result_port}"

        l_cmd = self.launch_cmd.format(debug=debug_opts,
                                       max_workers=max_workers,
                                       cores_per_worker=self.config.cores_per_worker,
                                       #mem_per_worker=self.config.mem_per_worker,
                                       prefetch_capacity=self.config.prefetch_capacity,
                                       task_url=worker_task_url,
                                       result_url=worker_result_url,
                                       nodes_per_block=self.config.provider.nodes_per_block,
                                       heartbeat_period=self.config.heartbeat_period,
                                       heartbeat_threshold=self.config.heartbeat_threshold,
                                       poll_period=self.config.poll_period,
                                       worker_mode=self.config.worker_mode,
                                       scheduler_mode=self.config.scheduler_mode,
                                       logdir=working_dir)
        self.launch_cmd = l_cmd
        logger.info("Launch command: {}".format(self.launch_cmd))

        if self.config.scaling_enabled:
            logger.info("Scaling ...")
            self.scale_out(self.config.provider.init_blocks)


    def get_tasks(self, count):
        """ Obtains a batch of tasks from the internal pending_task_queue

        Parameters
        ----------
        count: int
            Count of tasks to get from the queue

        Returns
        -------
        List of up to `count` tasks; may return fewer than `count`, down to an empty list
            eg. [{'task_id':<x>, 'buffer':<buf>} ... ]
        """
        tasks = []
        for i in range(0, count):
            try:
                x = self.pending_task_queue.get(block=False)
            except queue.Empty:
                break
            else:
                tasks.append(x)

        return tasks

    def migrate_tasks_to_internal(self, kill_event, status_request):
        """Pull tasks from the incoming tasks 0mq pipe onto the internal
        pending task queue

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        status_request : threading.Event
              Event used to signal that a status report has been requested.
        """
        logger.info("[TASK_PULL_THREAD] Starting")
        task_counter = 0
        poller = zmq.Poller()
        poller.register(self.task_incoming, zmq.POLLIN)

        while not kill_event.is_set():
            # Check when the last heartbeat was.
            # logger.debug(f"[TASK_PULL_THREAD] Last heartbeat: {self.last_heartbeat}")
            if int(time.time() - self.last_heartbeat) > (self.heartbeat_threshold * self.max_heartbeats_missed):
                logger.critical("[TASK_PULL_THREAD] Missed too many heartbeats. Setting kill event.")
                kill_event.set()
                break

            try:
                msg = self.task_incoming.recv_pyobj()
                self.last_heartbeat = time.time()
            except zmq.Again:
                # We just timed out while attempting to receive
                logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count))
                continue

            if msg == 'STOP':
                kill_event.set()
                break
            elif msg == 'STATUS_REQUEST':
                logger.info("Got STATUS_REQUEST")
                status_request.set()
            else:
                logger.info("[TASK_PULL_THREAD] Received task:{}".format(msg))
                task_type = self.get_container(msg['task_id'].split(";")[1])
                msg['container'] = task_type
                if task_type not in self.pending_task_queue:
                    self.pending_task_queue[task_type] = queue.Queue(maxsize=10 ** 6)
                self.pending_task_queue[task_type].put(msg)
                self.total_pending_task_count += 1
                logger.debug("[TASK_PULL_THREAD] pending task count: {}".format(self.total_pending_task_count))
                task_counter += 1
                logger.debug("[TASK_PULL_THREAD] Fetched task:{}".format(task_counter))

    def get_container(self, container_uuid):
        """ Get the container image location if it is not known to the interchange"""
        if container_uuid not in self.containers:
            if container_uuid == 'RAW' or not container_uuid:
                self.containers[container_uuid] = 'RAW'
            else:
                try:
                    container = self.fxs.get_container(container_uuid, self.config.container_type)
                except Exception:
                    logger.exception("[FETCH_CONTAINER] Unable to resolve container location")
                    self.containers[container_uuid] = 'RAW'
                else:
                    logger.info("[FETCH_CONTAINER] Got container info: {}".format(container))
                    self.containers[container_uuid] = container.get('location', 'RAW')
        return self.containers[container_uuid]

    def get_total_tasks_outstanding(self):
        """ Get the outstanding tasks in total
        """
        outstanding = {}
        for task_type in self.pending_task_queue:
            outstanding[task_type] = outstanding.get(task_type, 0) + self.pending_task_queue[task_type].qsize()
        for manager in self._ready_manager_queue:
            for task_type in self._ready_manager_queue[manager]['tasks']:
                outstanding[task_type] = outstanding.get(task_type, 0) + len(self._ready_manager_queue[manager]['tasks'][task_type])
        return outstanding

    def get_total_live_workers(self):
        """ Get the total active workers
        """
        active = 0
        for manager in self._ready_manager_queue:
            if self._ready_manager_queue[manager]['active']:
                active += self._ready_manager_queue[manager]['max_worker_count']
        return active

    def get_outstanding_breakdown(self):
        """ Get outstanding breakdown per manager and in the interchange queues

        Returns
        -------
        List of status for online elements
        [ (element, tasks_pending, status) ... ]
        """

        pending_on_interchange = self.total_pending_task_count
        # Reporting pending on interchange is a deviation from Parsl
        reply = [('interchange', pending_on_interchange, True)]
        for manager in self._ready_manager_queue:
            resp = (manager.decode('utf-8'),
                    sum([len(tids) for tids in self._ready_manager_queue[manager]['tasks'].values()]),
                    self._ready_manager_queue[manager]['active'])
            reply.append(resp)
        return reply

    def _hold_block(self, block_id):
        """ Sends hold command to all managers which are in a specific block

        Parameters
        ----------
        block_id : str
             Block identifier of the block to be put on hold
        """
        for manager in self._ready_manager_queue:
            if self._ready_manager_queue[manager]['active'] and \
               self._ready_manager_queue[manager]['block_id'] == block_id:
                logger.debug("[HOLD_BLOCK]: Sending hold to manager: {}".format(manager))
                self.hold_manager(manager)

    def hold_manager(self, manager):
        """ Put manager on hold
        Parameters
        ----------

        manager : str
          Manager id to be put on hold while being killed
        """
        if manager in self._ready_manager_queue:
            self._ready_manager_queue[manager]['active'] = False
            reply = True
        else:
            reply = False
        return reply

    def _command_server(self, kill_event):
        """ Command server to run async command to the interchange
        """
        logger.debug("[COMMAND] Command Server Starting")

        while not kill_event.is_set():
            try:
                command_req = self.command_channel.recv_pyobj()
                logger.debug("[COMMAND] Received command request: {}".format(command_req))
                if command_req == "OUTSTANDING_C":
                    reply = self.get_total_tasks_outstanding()

                elif command_req == "MANAGERS":
                    reply = self.get_outstanding_breakdown()

                elif command_req.startswith("HOLD_WORKER"):
                    cmd, s_manager = command_req.split(';')
                    manager = s_manager.encode('utf-8')
                    logger.info("[CMD] Received HOLD_WORKER for {}".format(manager))
                    if manager in self._ready_manager_queue:
                        self._ready_manager_queue[manager]['active'] = False
                        reply = True
                    else:
                        reply = False

                elif command_req == "HEARTBEAT":
                    logger.info("[CMD] Received heartbeat message from hub")
                    reply = "HBT,{}".format(self.endpoint_id)

                elif command_req == "SHUTDOWN":
                    logger.info("[CMD] Received SHUTDOWN command")
                    kill_event.set()
                    reply = True

                else:
                    reply = None

                logger.debug("[COMMAND] Reply: {}".format(reply))
                self.command_channel.send_pyobj(reply)

            except zmq.Again:
                logger.debug("[COMMAND] is alive")
                continue

    def stop(self):
        """Prepare the interchange for shutdown"""
        self._kill_event.set()

        self._task_puller_thread.join()
        self._command_thread.join()

    def start(self, poll_period=None):
        """ Start the Interchange

        Parameters:
        ----------
        poll_period : int
           poll_period in milliseconds
        """
        logger.info("Incoming ports bound")

        if poll_period is None:
            poll_period = self.poll_period

        start = time.time()
        count = 0

        self._kill_event = threading.Event()
        self._status_request = threading.Event()
        self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal,
                                                    args=(self._kill_event, self._status_request, ))
        self._task_puller_thread.start()

        self._command_thread = threading.Thread(target=self._command_server,
                                                args=(self._kill_event, ))
        self._command_thread.start()

        try:
            logger.debug("Starting strategy.")
            self.strategy.start(self)
        except RuntimeError as e:
            # This is raised when re-registering an endpoint as strategy already exists
            logger.debug("Failed to start strategy.")
            logger.info(e)

        poller = zmq.Poller()
        # poller.register(self.task_incoming, zmq.POLLIN)
        poller.register(self.task_outgoing, zmq.POLLIN)
        poller.register(self.results_incoming, zmq.POLLIN)

        # These are managers which we should examine in an iteration
        # for scheduling a job (or maybe any other attention?).
        # Anything altering the state of the manager should add it
        # onto this list.
        interesting_managers = set()

        while not self._kill_event.is_set():
            self.socks = dict(poller.poll(timeout=poll_period))

            # Listen for requests for work
            if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN:
                logger.debug("[MAIN] starting task_outgoing section")
                message = self.task_outgoing.recv_multipart()
                manager = message[0]

                if manager not in self._ready_manager_queue:
                    reg_flag = False

                    try:
                        msg = json.loads(message[1].decode('utf-8'))
                        reg_flag = True
                    except Exception:
                        logger.warning("[MAIN] Got a non-json registration message from manager:{}".format(
                            manager))
                        logger.debug("[MAIN] Message :\n{}\n".format(message))

                    # By default we set up to ignore bad nodes/registration messages.
                    self._ready_manager_queue[manager] = {'last': time.time(),
                                                          'reg_time': time.time(),
                                                          'free_capacity': {'total_workers': 0},
                                                          'max_worker_count': 0,
                                                          'active': True,
                                                          'tasks': collections.defaultdict(set),
                                                          'total_tasks': 0}
                    if reg_flag is True:
                        interesting_managers.add(manager)
                        logger.info("[MAIN] Adding manager: {} to ready queue".format(manager))
                        self._ready_manager_queue[manager].update(msg)
                        logger.info("[MAIN] Registration info for manager {}: {}".format(manager, msg))

                        if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or
                            msg['parsl_v'] != self.current_platform['parsl_v']):
                            logger.warning("[MAIN] Manager {} has incompatible version info with the interchange".format(manager))

                            if self.suppress_failure is False:
                                logger.debug("Setting kill event")
                                self._kill_event.set()
                                e = ManagerLost(manager)
                                result_package = {'task_id': -1,
                                                  'exception': self.serializer.serialize(e)}
                                pkl_package = pickle.dumps(result_package)
                                self.results_outgoing.send(pkl_package)
                                logger.warning("[MAIN] Sent failure reports, unregistering manager")
                            else:
                                logger.debug("[MAIN] Suppressing shutdown due to version incompatibility")

                    else:
                        # Registration has failed.
                        if self.suppress_failure is False:
                            logger.debug("Setting kill event for bad manager")
                            self._kill_event.set()
                            e = BadRegistration(manager, critical=True)
                            result_package = {'task_id': -1,
                                              'exception': self.serializer.serialize(e)}
                            pkl_package = pickle.dumps(result_package)
                            self.results_outgoing.send(pkl_package)
                        else:
                            logger.debug("[MAIN] Suppressing bad registration from manager:{}".format(
                                manager))

                else:
                    self._ready_manager_queue[manager]['last'] = time.time()
                    if message[1] == b'HEARTBEAT':
                        logger.debug("[MAIN] Manager {} sends heartbeat".format(manager))
                        self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE])
                    else:
                        manager_adv = pickle.loads(message[1])
                        logger.debug("[MAIN] Manager {} requested {}".format(manager, manager_adv))
                        self._ready_manager_queue[manager]['free_capacity'].update(manager_adv)
                        self._ready_manager_queue[manager]['free_capacity']['total_workers'] = sum(manager_adv.values())
                        interesting_managers.add(manager)

            # If we had received any requests, check if there are tasks that could be passed

            logger.debug("[MAIN] Managers count (total/interesting): {}/{}".format(
                len(self._ready_manager_queue),
                len(interesting_managers)))

            task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers,
                                                                             self.pending_task_queue,
                                                                             self._ready_manager_queue,
                                                                             scheduler_mode=self.config.scheduler_mode)
            self.total_pending_task_count -= dispatched_task

            for manager in task_dispatch:
                tasks = task_dispatch[manager]
                if tasks:
                    logger.info("[MAIN] Sending task message {} to manager {}".format(tasks, manager))
                    self.task_outgoing.send_multipart([manager, b'', pickle.dumps(tasks)])

            # Receive any results and forward to client
            if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN:
                logger.debug("[MAIN] entering results_incoming section")
                manager, *b_messages = self.results_incoming.recv_multipart()
                if manager not in self._ready_manager_queue:
                    logger.warning("[MAIN] Received a result from a un-registered manager: {}".format(manager))
                else:
                    logger.info("[MAIN] Got {} result items in batch".format(len(b_messages)))
                    for b_message in b_messages:
                        r = pickle.loads(b_message)
                        # logger.debug("[MAIN] Received result for task {} from {}".format(r['task_id'], manager))
                        task_type = self.containers[r['task_id'].split(';')[1]]
                        self._ready_manager_queue[manager]['tasks'][task_type].remove(r['task_id'])
                    self._ready_manager_queue[manager]['total_tasks'] -= len(b_messages)
                    self.results_outgoing.send_multipart(b_messages)
                    logger.debug("[MAIN] Current tasks: {}".format(self._ready_manager_queue[manager]['tasks']))
                logger.debug("[MAIN] leaving results_incoming section")

            # logger.debug("[MAIN] entering bad_managers section")
            bad_managers = [manager for manager in self._ready_manager_queue if
                            time.time() - self._ready_manager_queue[manager]['last'] > self.heartbeat_threshold]
            for manager in bad_managers:
                logger.debug("[MAIN] Last: {} Current: {}".format(self._ready_manager_queue[manager]['last'], time.time()))
                logger.warning("[MAIN] Too many heartbeats missed for manager {}".format(manager))
                e = ManagerLost(manager)
                for task_type in self._ready_manager_queue[manager]['tasks']:
                    for tid in self._ready_manager_queue[manager]['tasks'][task_type]:
                        result_package = {'task_id': tid, 'exception': self.serializer.serialize(e)}
                        pkl_package = pickle.dumps(result_package)
                        self.results_outgoing.send(pkl_package)
                logger.warning("[MAIN] Sent failure reports, unregistering manager")
                self._ready_manager_queue.pop(manager, None)
                if manager in interesting_managers:
                    interesting_managers.remove(manager)
            logger.debug("[MAIN] ending one main loop iteration")

            if self._status_request.is_set():
                logger.info("status request response")
                result_package = self.get_status_report()
                pkl_package = pickle.dumps(result_package)
                self.results_outgoing.send(pkl_package)
                logger.info("[MAIN] Sent info response")
                self._status_request.clear()

        delta = time.time() - start
        logger.info("Processed {} tasks in {} seconds".format(count, delta))
        logger.warning("Exiting")

    def get_status_report(self):
        """ Get utilization numbers
        """
        total_cores = 0
        total_mem = 0
        core_hrs = 0
        active_managers = 0
        free_capacity = 0
        outstanding_tasks = self.get_total_tasks_outstanding()
        pending_tasks = self.total_pending_task_count
        num_managers = len(self._ready_manager_queue)
        live_workers = self.get_total_live_workers()
        
        for manager in self._ready_manager_queue:
            total_cores += self._ready_manager_queue[manager]['cores']
            total_mem += self._ready_manager_queue[manager]['mem']
            active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time'])
            core_hrs += (active_dur * total_cores) / 3600
            if self._ready_manager_queue[manager]['active']:
                active_managers += 1
            free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers']

        result_package = {'task_id': -2,
                          'info': {'total_cores': total_cores,
                                   'total_mem' : total_mem,
                                   'new_core_hrs': core_hrs - self.last_core_hr_counter,
                                   'total_core_hrs': round(core_hrs, 2),
                                   'managers': num_managers,
                                   'active_managers': active_managers,
                                   'total_workers': live_workers,
                                   'idle_workers': free_capacity,
                                   'pending_tasks': pending_tasks,
                                   'outstanding_tasks': outstanding_tasks,
                                   'worker_mode': self.config.worker_mode,
                                   'scheduler_mode': self.config.scheduler_mode,
                                   'scaling_enabled': self.config.scaling_enabled,
                                   'mem_per_worker': self.config.mem_per_worker,
                                   'cores_per_worker': self.config.cores_per_worker,
                                   'prefetch_capacity': self.config.prefetch_capacity,
                                   'max_blocks': self.config.provider.max_blocks,
                                   'min_blocks': self.config.provider.min_blocks,
                                   'max_workers_per_node': self.config.max_workers_per_node,
                                   'nodes_per_block': self.config.provider.nodes_per_block
        }}

        self.last_core_hr_counter = core_hrs
        return result_package

    def scale_out(self, blocks=1, task_type=None):
        """Scales out the number of blocks by "blocks"

        Raises:
             NotImplementedError
        """
        r = []
        for i in range(blocks):
            if self.config.provider:
                self._block_counter += 1
                external_block_id = str(self._block_counter)
                if not task_type and self.config.scheduler_mode == 'hard':
                    launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW')
                else:
                    launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type)
                if not task_type:
                    internal_block = self.config.provider.submit(launch_cmd, 1)
                else:
                    internal_block = self.config.provider.submit(launch_cmd, 1, task_type)
                logger.debug("Launched block {}->{}".format(external_block_id, internal_block))
                if not internal_block:
                    raise ScalingFailed(self.config.provider.label,
                                        "Attempt to provision nodes via the provider has failed")
                self.blocks[external_block_id] = internal_block
                self.block_id_map[internal_block] = external_block_id
            else:
                logger.error("No execution provider available")
                r = None
        return r

    def scale_in(self, blocks=None, block_ids=None, task_type=None):
        """Scale in the number of active blocks by specified amount.

        Parameters
        ----------
        blocks : int
            # of blocks to terminate

        block_ids : [str.. ]
            List of external block ids to terminate
        """
        if task_type:
            logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type))
            if self.config.scaling_enabled and self.config.provider:
                to_kill, r = self.config.provider.cancel(blocks, task_type)
                logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r))
                for job in to_kill:
                    logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job))
                    block_id = self.block_id_map[job]
                    logger.info("[scale_in] Holding block {}".format(block_id))
                    self._hold_block(block_id)
                    self.blocks.pop(block_id)
                return r

        if block_ids:
            block_ids_to_kill = block_ids
        else:
            block_ids_to_kill = list(self.blocks.keys())[:blocks]

        # Try a polite terminate
        # TODO : Missing logic to hold blocks
        for block_id in block_ids_to_kill:
            self._hold_block(block_id)

        # Now kill via provider
        to_kill = [self.blocks.pop(bid) for bid in block_ids_to_kill]

        r = None
        if self.config.scaling_enabled and self.config.provider:
            r = self.config.provider.cancel(to_kill)

        return r

    def provider_status(self):
        """ Get status of all blocks from the provider
        """
        status = []
        if self.config.provider:
            logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values())))
            status = self.config.provider.status(list(self.blocks.values()))
            logger.debug("[MAIN] The status is {}".format(status))

        return status
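A hedged sketch of driving the Interchange defined above. The `config` argument is assumed to be a funcx Config object with a provider attached and reachable client-side ZMQ ports; the real endpoint process wires all of this up itself, so this only illustrates the start/stop surface shown in the example.

import threading


def run_interchange(config, endpoint_id):
    # `config` and `endpoint_id` are assumed inputs; nothing here is the
    # endpoint's official entry point.
    ic = Interchange(config,
                     client_address="127.0.0.1",
                     client_ports=(50055, 50056, 50057),
                     endpoint_id=endpoint_id)

    # start() blocks until the kill event is set (a SHUTDOWN command or too many
    # missed heartbeats), so run it on its own thread.
    runner = threading.Thread(target=ic.start, kwargs={"poll_period": 10})
    runner.start()
    return ic, runner

# Shutdown later looks like: ic.stop() sets the kill event and joins the
# task-puller and command threads, after which runner.join() returns.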
Example #28
0
class FuncXClient(BaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://dev.funcx.org/api/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout : int
            Timeout for any call to the service, in seconds.
            Default is no timeout.

        force_login : bool
            Whether to force a login to get new credentials.

        fx_authorizer : :class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`
            A custom authorizer instance to communicate with funcX.
            Default: ``None``, will be created.

        funcx_service_address : str
            The address of the funcX web service to communicate with.
            Default: https://dev.funcx.org/api/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        native_client = NativeClient(client_id=self.CLIENT_ID)

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"

        if not fx_authorizer:
            native_client.login(
                requested_scopes=[fx_scope],
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = native_client.get_authorizers_by_scope(
                requested_scopes=[fx_scope])
            fx_authorizer = all_authorizers[fx_scope]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

    def logout(self):
        """Remove credentials from your local system
        """
        logout()

    def get_task_status(self, task_id):
        """Get the status of a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Status block containing "status" key.
        """

        r = self.get("{task_id}/status".format(task_id=task_id))
        return json.loads(r.text)

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """

        r = self.get("{task_id}/status".format(task_id=task_id))

        logger.info(f"Got from globus : {r}")
        r_dict = json.loads(r.text)

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
            except Exception:
                raise Exception(
                    "Failure during deserialization of the result object")
            else:
                return r_obj

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise Exception(
                    "Failure during deserialization of the Task's exception object"
                )
            else:
                r_exception.reraise()

        else:
            raise Exception("Task pending")

    def run(self,
            *args,
            endpoint_id=None,
            function_id=None,
            asynchronous=False,
            **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'func': function_id,
            'payload': payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if 'task_uuid' not in r:
            raise MalformedResponse(r)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """
        return r['task_uuid']

    def register_endpoint(self, name, endpoint_uuid, description=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "description": description
        }

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        tuple
            The endpoint UUID and the list of containers to launch
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['container']

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution

        function_name : str
            The entry point (function name) of the function. Default: None

        container_uuid : str
            Container UUID from registration with funcX

        description : str
            Description of the function

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['function_uuid']
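A hedged end-to-end sketch against the client above: register a function, submit it to an endpoint, and poll for the result. The endpoint UUID is a placeholder, constructing FuncXClient() triggers a Globus login, and the polling loop keys off the "Task pending" exception that get_result raises in this example.

import time


def double(x):
    return x * 2


fxc = FuncXClient()
func_uuid = fxc.register_function(double, description="double a number")
task_id = fxc.run(21, endpoint_id="<endpoint-uuid>", function_id=func_uuid)

while True:
    try:
        print(fxc.get_result(task_id))   # expected: 42
        break
    except Exception as exc:
        if "pending" not in str(exc).lower():
            raise                        # a real failure from the task
        time.sleep(2)                    # still pending; poll again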
Example #29
0
import time
import uuid
from multiprocessing import Queue

from funcx.serialize import FuncXSerializer
# Import path for the executor is assumed; adjust to the installed funcx_endpoint layout.
from funcx_endpoint.executors import HighThroughputExecutor
from funcx_endpoint.executors.high_throughput.messages import Task


def double(x):
    return x * 2


if __name__ == "__main__":

    results_queue = Queue()
    #    set_file_logger('executor.log', name='funcx_endpoint', level=logging.DEBUG)
    htex = HighThroughputExecutor(interchange_local=True, passthrough=True)

    htex.start(results_passthrough=results_queue)
    htex._start_remote_interchange_process()
    fx_serializer = FuncXSerializer()

    for i in range(10):
        task_id = str(uuid.uuid4())
        args = (i, )
        kwargs = {}

        fn_code = fx_serializer.serialize(double)
        ser_code = fx_serializer.pack_buffers([fn_code])
        ser_params = fx_serializer.pack_buffers(
            [fx_serializer.serialize(args),
             fx_serializer.serialize(kwargs)])

        payload = Task(task_id, "RAW", ser_code + ser_params)
        f = htex.submit_raw(payload.pack())
        time.sleep(0.5)
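A quick, hedged aside on the payload built above: FuncXSerializer round-trips locally, so a couple of asserts confirm that the packed code and parameters encode what we expect before anything is shipped to a worker.

fxs = FuncXSerializer()

fn_buffer = fxs.serialize(double)
assert fxs.deserialize(fn_buffer)(21) == 42      # the function survives the round trip

args_buffer = fxs.serialize((21,))
assert fxs.deserialize(args_buffer) == (21,)     # so do the positional arguments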
Example #30
0
to be addressed.
"""

from concurrent.futures import Future
import os
import logging
import threading
import queue
import pickle
import daemon
from multiprocessing import Process, Queue

#from ipyparallel.serialize import pack_apply_message  # ,unpack_apply_message
from ipyparallel.serialize import deserialize_object  # ,serialize_object
from funcx.serialize import FuncXSerializer
fx_serializer = FuncXSerializer()

from parsl.executors.high_throughput import interchange
from parsl.executors.errors import *
from parsl.executors.base import ParslExecutor
from parsl.dataflow.error import ConfigurationError

from parsl.utils import RepresentationMixin
from parsl.providers import LocalProvider


from funcx.executors.high_throughput import zmq_pipes

logger = logging.getLogger(__name__)

BUFFER_THRESHOLD = 1024 * 1024