示例#1
0
class Manager(object):
    """ Manager manages task execution by the workers

                |         0mq              |    Manager         |   Worker Processes
                |                          |                    |
                | <-----Request N task-----+--Count task reqs   |      Request task<--+
    Interchange | -------------------------+->Receive task batch|          |          |
                |                          |  Distribute tasks--+----> Get(block) &   |
                |                          |                    |      Execute task   |
                |                          |                    |          |          |
                | <------------------------+--Return results----+----  Post result    |
                |                          |                    |          |          |
                |                          |                    |          +----------+
                |                          |                IPC-Queues

    """
    def __init__(
            self,
            task_q_url="tcp://127.0.0.1:50097",
            result_q_url="tcp://127.0.0.1:50098",
            max_queue_size=10,
            cores_per_worker=1,
            max_workers=float('inf'),
            uid=None,
            heartbeat_threshold=120,
            heartbeat_period=30,
            logdir=None,
            debug=False,
            block_id=None,
            internal_worker_port_range=(50000, 60000),
            mode="singularity_reuse",
            container_image=None,
            # TODO : This should be 10ms
            poll_period=100):
        """
        Parameters
        ----------
        task_q_url : str
             0mq url of the interchange's task channel that this manager
             connects to. Default: "tcp://127.0.0.1:50097"

        result_q_url : str
             0mq url of the interchange's result channel that this manager
             connects to. Default: "tcp://127.0.0.1:50098"

        max_queue_size : int
             Base cap on pending tasks; the effective cap used at runtime is
             max_queue_size + worker_count. Default: 10

        uid : str
             string unique identifier; used as the 0mq socket identity for
             both channels, so it must be set (None will raise on .encode())

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        max_workers : int
             caps the maximum number of workers that can be launched.
             default: infinity

        heartbeat_threshold : int
             Number of seconds since the last message from the interchange after which the
             interchange is assumed to be un-available, and the manager initiates shutdown.
             Default: 120s

        heartbeat_period : int
             Number of seconds after which a heartbeat message is sent to the interchange.
             Default: 30s

        logdir : str
             Directory path handed to spawned workers for their logs. Default: None

        debug : bool
             Enables debug logging in spawned workers when True. Default: False

        block_id : str
             Identifier of the provider block this manager belongs to; reported
             to the interchange in the registration message. Default: None

        internal_worker_port_range : tuple(int, int)
             Port range from which the port(s) for the workers to connect to the manager is picked.
             Default: (50000,60000)

        mode : str
             Pick between 3 supported modes for the worker:
              1. no_container : Worker launched without containers
              2. singularity_reuse : Worker launched inside a singularity container that will be reused
              3. singularity_single_use : Each worker and task runs inside a new container instance.

        container_image : str
             Path or identifier for the container to be used. Default: None

        poll_period : int
             Timeout period used by the manager in milliseconds. Default: 100
             (the TODO in the signature notes this should eventually be 10ms)
        """

        logger.info("Manager started")

        # DEALER socket towards the interchange for receiving tasks; the uid
        # is the socket identity so the interchange can route to this manager.
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        # Linger is set to 0, so that the manager can exit even when there might be
        # messages in the pipe
        self.task_incoming.setsockopt(zmq.LINGER, 0)
        self.task_incoming.connect(task_q_url)

        self.logdir = logdir
        self.debug = debug
        self.block_id = block_id
        # Second DEALER socket, mirroring task_incoming, for streaming results
        # back to the interchange (drained by push_results).
        self.result_outgoing = self.context.socket(zmq.DEALER)
        self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        self.result_outgoing.setsockopt(zmq.LINGER, 0)
        self.result_outgoing.connect(result_q_url)
        logger.info("Manager connected")

        self.uid = uid

        self.mode = mode
        self.container_image = container_image
        self.cores_on_node = multiprocessing.cpu_count()
        self.max_workers = max_workers
        # NOTE(review): attribute name carries a typo ("cores_per_workers");
        # kept as-is since external code may already read it.
        self.cores_per_workers = cores_per_worker
        # Available memory in GiB, rounded to one decimal; reported at registration.
        self.available_mem_on_node = round(
            psutil.virtual_memory().available / (2**30), 1)
        # Worker count is bounded by both max_workers and the per-core budget.
        self.worker_count = min(
            max_workers, math.floor(self.cores_on_node / cores_per_worker))
        self.worker_map = WorkerMap(self.worker_count)

        self.internal_worker_port_range = internal_worker_port_range

        # ROUTER socket on which workers connect back; HWM set to 0 (unlimited)
        # so messages to workers are never silently dropped.
        self.funcx_task_socket = self.context.socket(zmq.ROUTER)
        self.funcx_task_socket.set_hwm(0)
        self.address = '127.0.0.1'
        self.worker_port = self.funcx_task_socket.bind_to_random_port(
            "tcp://*",
            min_port=self.internal_worker_port_range[0],
            max_port=self.internal_worker_port_range[1])

        logger.info(
            "Manager listening on {} port for incoming worker connections".
            format(self.worker_port))

        # Pending tasks keyed by task type; 'RAW' is the default type.
        self.task_queues = {'RAW': queue.Queue()}

        # Results posted here by the task-pull loop; drained by push_results().
        self.pending_result_queue = multiprocessing.Queue()

        self.max_queue_size = max_queue_size + self.worker_count
        self.tasks_per_round = 1

        self.heartbeat_period = heartbeat_period
        self.heartbeat_threshold = heartbeat_threshold
        self.poll_period = poll_period
        self.serializer = FuncXSerializer()
        self.next_worker_q = []  # FIFO queue for spinning up workers.

    def create_reg_message(self):
        """ Build the JSON-encoded registration message that identifies this
        manager (worker pool) to the interchange.

        Returns
        -------
        bytes
            UTF-8 encoded JSON payload with version, capacity, and host info.
        """
        registration = {
            'parsl_v': PARSL_VERSION,
            'python_v': "{}.{}.{}".format(sys.version_info.major,
                                          sys.version_info.minor,
                                          sys.version_info.micro),
            'worker_count': self.worker_count,
            'cores': self.cores_on_node,
            'mem': self.available_mem_on_node,
            'block_id': self.block_id,
            'os': platform.system(),
            'hname': platform.node(),
            'dir': os.getcwd(),
        }
        return json.dumps(registration).encode('utf-8')

    def heartbeat(self):
        """ Notify the interchange that this manager is alive.

        Sends the 4-byte little-endian HEARTBEAT_CODE over the task channel.
        """
        beat = HEARTBEAT_CODE.to_bytes(4, "little")
        sent = self.task_incoming.send(beat)
        logger.debug("Return from heartbeat: {}".format(sent))

    def pull_tasks(self, kill_event):
        """ Pull tasks from the incoming tasks 0mq pipe onto the internal
        pending task queue

        Main event loop of the manager. Each iteration:
            receive results and task requests from the workers
            receive tasks/heartbeats from the Interchange
            match tasks to workers
            if task doesn't have appropriate worker type:
                 launch worker of type.. with LRU or some sort of caching strategy.
            if workers >> tasks:
                 advertize available capacity

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """
        logger.info("[TASK PULL THREAD] starting")
        # Poll both the interchange channel and the worker-facing ROUTER socket
        # in a single loop.
        poller = zmq.Poller()
        poller.register(self.task_incoming, zmq.POLLIN)
        poller.register(self.funcx_task_socket, zmq.POLLIN)

        # Send a registration message
        msg = self.create_reg_message()
        logger.debug("Sending registration message: {}".format(msg))
        self.task_incoming.send(msg)
        last_beat = time.time()
        last_interchange_contact = time.time()
        task_recv_counter = 0
        task_done_counter = 0

        # poll_timer backs off exponentially (doubling, capped at
        # heartbeat_period) when idle, and resets when traffic arrives.
        poll_timer = self.poll_period

        new_worker_map = None
        while not kill_event.is_set():
            # Disabling the check on ready_worker_queue disables batching
            logger.debug("[TASK_PULL_THREAD] Loop start")
            pending_task_count = task_recv_counter - task_done_counter
            ready_worker_count = self.worker_map.ready_worker_count()
            logger.debug(
                "[TASK_PULL_THREAD pending_task_count: {} Ready_worker_count: {}"
                .format(pending_task_count, ready_worker_count))

            if time.time() > last_beat + self.heartbeat_period:
                self.heartbeat()
                last_beat = time.time()

            # Advertise capacity: ask the interchange for as many tasks as we
            # have idle workers, as long as the pending backlog has room.
            if pending_task_count < self.max_queue_size and ready_worker_count > 0:
                logger.debug("[TASK_PULL_THREAD] Requesting tasks: {}".format(
                    ready_worker_count))
                msg = (ready_worker_count.to_bytes(4, "little"))
                self.task_incoming.send(msg)

            # Receive results from the workers, if any
            socks = dict(poller.poll(timeout=poll_timer))
            if self.funcx_task_socket in socks and socks[
                    self.funcx_task_socket] == zmq.POLLIN:
                try:
                    # Worker frames: [worker_id, message_type, payload]
                    w_id, m_type, message = self.funcx_task_socket.recv_multipart(
                    )
                    if m_type == b'REGISTER':
                        reg_info = pickle.loads(message)
                        logger.debug(
                            "Registration received from worker:{} {}".format(
                                w_id, reg_info))

                        # Worker moved from pending to active; record its type.
                        self.worker_map.pending_workers -= 1
                        self.worker_map.active_workers += 1
                        self.worker_map.register_worker(
                            w_id, reg_info['worker_type'])

                    elif m_type == b'TASK_RET':
                        logger.debug(
                            "Result received from worker: {}".format(w_id))
                        logger.debug(
                            "[TASK_PULL_THREAD] Got result: {}".format(
                                message))
                        # Hand the raw result to push_results() and mark the
                        # worker as idle again.
                        self.pending_result_queue.put(message)
                        self.worker_map.put_worker(w_id)
                        task_done_counter += 1

                    elif m_type == b'WRKR_DIE':
                        logger.debug(
                            "[WORKER_REMOVE] Removing worker from worker_map..."
                        )
                        logger.debug("Ready worker counts: {}".format(
                            self.worker_map.ready_worker_type_counts))
                        logger.debug("Total worker counts: {}".format(
                            self.worker_map.total_worker_type_counts))
                        self.worker_map.remove_worker(w_id)

                except Exception as e:
                    # Best-effort: a malformed worker message must not kill the
                    # manager loop.
                    logger.warning(
                        "[TASK_PULL_THREAD] FUNCX : caught {}".format(e))

            # Spin up any new workers according to the worker queue.
            # Returns the total number of containers that have spun up.
            spin_up = self.worker_map.spin_up_workers(
                self.next_worker_q,
                debug=self.debug,
                address=self.address,
                uid=self.uid,
                logdir=self.logdir,
                worker_port=self.worker_port)
            logger.debug("[SPIN UP]: Spun up {} containers".format(spin_up))

            # Receive task batches from Interchange and forward to workers
            if self.task_incoming in socks and socks[
                    self.task_incoming] == zmq.POLLIN:
                # Traffic arrived: reset the poll backoff to react quickly.
                poll_timer = 0
                _, pkl_msg = self.task_incoming.recv_multipart()
                tasks = pickle.loads(pkl_msg)
                last_interchange_contact = time.time()

                if tasks == 'STOP':
                    logger.critical("[TASK_PULL_THREAD] Received stop request")
                    kill_event.set()
                    break

                elif tasks == HEARTBEAT_CODE:
                    logger.debug("Got heartbeat from interchange")

                else:
                    task_recv_counter += len(tasks)
                    logger.debug(
                        "[TASK_PULL_THREAD] Got tasks: {} of {}".format(
                            [t['task_id'] for t in tasks], task_recv_counter))

                    for task in tasks:
                        # Task type is encoded in the task_id after the ';'
                        # separator (e.g. "<uuid>;RAW").
                        task_type = task['task_id'].split(';')[1]

                        logger.debug("[TASK DEBUG] Task is of type: {}".format(
                            task_type))

                        # First task of a new type creates its queue and
                        # initialises the worker-type counter.
                        if task_type not in self.task_queues:
                            self.task_queues[task_type] = queue.Queue()
                            self.worker_map.total_worker_type_counts[
                                task_type] = 0
                        self.task_queues[task_type].put(task)
                        logger.debug(
                            "Task {} pushed to a task queue {}".format(
                                task, task_type))

            else:
                logger.debug("[TASK_PULL_THREAD] No incoming tasks")
                # Limit poll duration to heartbeat_period
                # heartbeat_period is in s vs poll_timer in ms
                if not poll_timer:
                    poll_timer = self.poll_period
                poll_timer = min(self.heartbeat_period * 1000, poll_timer * 2)

                # Only check if no messages were received.
                if time.time(
                ) > last_interchange_contact + self.heartbeat_threshold:
                    logger.critical(
                        "[TASK_PULL_THREAD] Missing contact with interchange beyond heartbeat_threshold"
                    )
                    kill_event.set()
                    logger.critical("[TASK_PULL_THREAD] Exiting")
                    break

            logger.debug("Task queues: {}".format(self.task_queues))
            logger.debug("To-Die Counts: {}".format(
                self.worker_map.to_die_count))
            logger.debug("Alive worker counts: {}".format(
                self.worker_map.total_worker_type_counts))

            # Recompute the desired worker mix from the current task backlog.
            new_worker_map = naive_scheduler(self.task_queues,
                                             self.worker_count,
                                             new_worker_map,
                                             self.worker_map.to_die_count,
                                             logger=logger)
            logger.debug(
                "[SCHEDULER] New worker map: {}".format(new_worker_map))

            #  Count the workers of each type that need to be removed
            if new_worker_map is not None:
                spin_downs = self.worker_map.spin_down_workers(new_worker_map)

                for w_type in spin_downs:
                    self.remove_worker_init(w_type)

            # NOTE: Wipes the queue -- previous scheduling loops don't affect what's needed now.
            if new_worker_map is not None:
                self.next_worker_q = self.worker_map.get_next_worker_q(
                    new_worker_map)

            current_worker_map = self.worker_map.get_worker_counts()
            for task_type in current_worker_map:
                # 'slots' is a capacity entry, not a worker type.
                if task_type == 'slots':
                    continue

                # *** Match tasks to workers *** #
                else:
                    available_workers = current_worker_map[task_type]
                    logger.debug("Available workers of type {}: {}".format(
                        task_type, available_workers))

                    for i in range(available_workers):
                        # Dispatch while both a pending task and an idle worker
                        # of this type exist.
                        if task_type in self.task_queues and not self.task_queues[task_type].qsize() == 0 \
                                and not self.worker_map.worker_queues[task_type].qsize() == 0:

                            logger.debug(
                                "Task type {} has task queue size {}".format(
                                    task_type,
                                    self.task_queues[task_type].qsize()))
                            logger.debug(
                                "... and available workers: {}".format(
                                    self.worker_map.worker_queues[task_type].
                                    qsize()))

                            task = self.task_queues[task_type].get()
                            worker_id = self.worker_map.get_worker(task_type)

                            logger.debug("Sending task {} to {}".format(
                                task['task_id'], worker_id))
                            # ROUTER send: first frame addresses the worker.
                            to_send = [
                                worker_id,
                                pickle.dumps(task['task_id']), task['buffer']
                            ]
                            self.funcx_task_socket.send_multipart(to_send)
                            logger.debug("Sending complete!")

    def push_results(self, kill_event, max_result_batch_size=1):
        """ Listens on the pending_result_queue and sends out results via 0mq

        Results are drained from the internal queue and flushed to the
        interchange in batches, either when the batch reaches the queue-size
        cap or when the poll period has elapsed.

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """

        logger.debug("[RESULT_PUSH_THREAD] Starting thread")

        # push_poll_period must be atleast 10 ms; convert ms -> s for Queue.get
        push_poll_period = max(10, self.poll_period) / 1000
        logger.debug("[RESULT_PUSH_THREAD] push poll period: {}".format(
            push_poll_period))

        last_beat = time.time()
        batch = []

        while not kill_event.is_set():
            # Block briefly for the next result; a timeout just means we fall
            # through to the flush check below.
            try:
                result = self.pending_result_queue.get(
                    block=True, timeout=push_poll_period)
                batch.append(result)
            except queue.Empty:
                pass
            except Exception as e:
                logger.exception(
                    "[RESULT_PUSH_THREAD] Got an exception: {}".format(e))

            # If we have reached poll_period duration or timer has expired, we send results
            flush_due = time.time() > last_beat + push_poll_period
            if len(batch) >= self.max_queue_size or flush_due:
                last_beat = time.time()
                if batch:
                    self.result_outgoing.send_multipart(batch)
                    batch = []

        logger.critical("[RESULT_PUSH_THREAD] Exiting")

    def remove_worker_init(self, worker_type):
        """ Schedule the removal of one worker of the given worker_type.

        A KILL poison-pill task is appended to that type's task queue;
        whichever worker of the type dequeues it will shut itself down.

        Assumption : All workers of the same type are uniform, and therefore don't discriminate when killing.
        """

        logger.debug("[WORKER_REMOVE] Appending KILL message to worker queue")
        self.worker_map.to_die_count[worker_type] += 1
        kill_task = {"task_id": pickle.dumps(b"KILL"), "buffer": b'KILL'}
        self.task_queues[worker_type].put(kill_task)

    def start(self):
        """ Start the manager: launch the initial RAW worker, start the result
        pusher thread, and run the task-pull loop until shutdown.

        * while True:
            Receive tasks and start appropriate workers
            Push tasks to available workers
            Forward results
        """

        # Reset the per-type task queues; 'RAW' is the default type.
        self.task_queues = {
            'RAW': queue.Queue()
        }  # k-v: task_type - task_q (PriorityQueue) -- default = RAW

        # Launch one initial RAW worker so the pool can accept work immediately.
        self.workers = [
            self.worker_map.add_worker(worker_id=str(
                self.worker_map.worker_counter),
                                       worker_type='RAW',
                                       address=self.address,
                                       debug=self.debug,
                                       uid=self.uid,
                                       logdir=self.logdir,
                                       worker_port=self.worker_port)
        ]
        self.worker_map.worker_counter += 1
        self.worker_map.pending_workers += 1

        logger.debug("Initial workers launched")
        # Results are pushed from a daemon-less helper thread; the task-pull
        # loop runs on the current thread and blocks until kill_event is set.
        self._kill_event = threading.Event()
        self._result_pusher_thread = threading.Thread(
            target=self.push_results, args=(self._kill_event, ))
        self._result_pusher_thread.start()

        self.pull_tasks(self._kill_event)
        logger.info("Waiting")
示例#2
0
    def __init__(
            self,
            task_q_url="tcp://127.0.0.1:50097",
            result_q_url="tcp://127.0.0.1:50098",
            max_queue_size=10,
            cores_per_worker=1,
            max_workers=float('inf'),
            uid=None,
            heartbeat_threshold=120,
            heartbeat_period=30,
            logdir=None,
            debug=False,
            block_id=None,
            internal_worker_port_range=(50000, 60000),
            mode="singularity_reuse",
            container_image=None,
            # TODO : This should be 10ms
            poll_period=100):
        """
        Parameters
        ----------
        task_q_url : str
             0mq url of the interchange's task channel that this manager
             connects to. Default: "tcp://127.0.0.1:50097"

        result_q_url : str
             0mq url of the interchange's result channel that this manager
             connects to. Default: "tcp://127.0.0.1:50098"

        max_queue_size : int
             Base cap on pending tasks; the effective cap used at runtime is
             max_queue_size + worker_count. Default: 10

        uid : str
             string unique identifier; used as the 0mq socket identity for
             both channels, so it must be set (None will raise on .encode())

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        max_workers : int
             caps the maximum number of workers that can be launched.
             default: infinity

        heartbeat_threshold : int
             Number of seconds since the last message from the interchange after which the
             interchange is assumed to be un-available, and the manager initiates shutdown.
             Default: 120s

        heartbeat_period : int
             Number of seconds after which a heartbeat message is sent to the interchange.
             Default: 30s

        logdir : str
             Directory path handed to spawned workers for their logs. Default: None

        debug : bool
             Enables debug logging in spawned workers when True. Default: False

        block_id : str
             Identifier of the provider block this manager belongs to; reported
             to the interchange in the registration message. Default: None

        internal_worker_port_range : tuple(int, int)
             Port range from which the port(s) for the workers to connect to the manager is picked.
             Default: (50000,60000)

        mode : str
             Pick between 3 supported modes for the worker:
              1. no_container : Worker launched without containers
              2. singularity_reuse : Worker launched inside a singularity container that will be reused
              3. singularity_single_use : Each worker and task runs inside a new container instance.

        container_image : str
             Path or identifier for the container to be used. Default: None

        poll_period : int
             Timeout period used by the manager in milliseconds. Default: 100
             (the TODO in the signature notes this should eventually be 10ms)
        """

        logger.info("Manager started")

        # DEALER socket towards the interchange for receiving tasks; the uid
        # is the socket identity so the interchange can route to this manager.
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        # Linger is set to 0, so that the manager can exit even when there might be
        # messages in the pipe
        self.task_incoming.setsockopt(zmq.LINGER, 0)
        self.task_incoming.connect(task_q_url)

        self.logdir = logdir
        self.debug = debug
        self.block_id = block_id
        # Second DEALER socket, mirroring task_incoming, for streaming results
        # back to the interchange.
        self.result_outgoing = self.context.socket(zmq.DEALER)
        self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        self.result_outgoing.setsockopt(zmq.LINGER, 0)
        self.result_outgoing.connect(result_q_url)
        logger.info("Manager connected")

        self.uid = uid

        self.mode = mode
        self.container_image = container_image
        self.cores_on_node = multiprocessing.cpu_count()
        self.max_workers = max_workers
        # NOTE(review): attribute name carries a typo ("cores_per_workers");
        # kept as-is since external code may already read it.
        self.cores_per_workers = cores_per_worker
        # Available memory in GiB, rounded to one decimal; reported at registration.
        self.available_mem_on_node = round(
            psutil.virtual_memory().available / (2**30), 1)
        # Worker count is bounded by both max_workers and the per-core budget.
        self.worker_count = min(
            max_workers, math.floor(self.cores_on_node / cores_per_worker))
        self.worker_map = WorkerMap(self.worker_count)

        self.internal_worker_port_range = internal_worker_port_range

        # ROUTER socket on which workers connect back; HWM set to 0 (unlimited)
        # so messages to workers are never silently dropped.
        self.funcx_task_socket = self.context.socket(zmq.ROUTER)
        self.funcx_task_socket.set_hwm(0)
        self.address = '127.0.0.1'
        self.worker_port = self.funcx_task_socket.bind_to_random_port(
            "tcp://*",
            min_port=self.internal_worker_port_range[0],
            max_port=self.internal_worker_port_range[1])

        logger.info(
            "Manager listening on {} port for incoming worker connections".
            format(self.worker_port))

        # Pending tasks keyed by task type; 'RAW' is the default type.
        self.task_queues = {'RAW': queue.Queue()}

        # Results posted here by the task-pull loop; drained by push_results().
        self.pending_result_queue = multiprocessing.Queue()

        self.max_queue_size = max_queue_size + self.worker_count
        self.tasks_per_round = 1

        self.heartbeat_period = heartbeat_period
        self.heartbeat_threshold = heartbeat_threshold
        self.poll_period = poll_period
        self.serializer = FuncXSerializer()
        self.next_worker_q = []  # FIFO queue for spinning up workers.
示例#3
0
class Manager(object):
    """ Manager manages task execution by the workers

                |         0mq              |    Manager         |   Worker Processes
                |                          |                    |
                | <-----Request N task-----+--Count task reqs   |      Request task<--+
    Interchange | -------------------------+->Receive task batch|          |          |
                |                          |  Distribute tasks--+----> Get(block) &   |
                |                          |                    |      Execute task   |
                |                          |                    |          |          |
                | <------------------------+--Return results----+----  Post result    |
                |                          |                    |          |          |
                |                          |                    |          +----------+
                |                          |                IPC-Queues

    """
    def __init__(
            self,
            task_q_url="tcp://127.0.0.1:50097",
            result_q_url="tcp://127.0.0.1:50098",
            max_queue_size=10,
            cores_per_worker=1,
            max_workers=float('inf'),
            uid=None,
            heartbeat_threshold=120,
            heartbeat_period=30,
            logdir=None,
            debug=False,
            internal_worker_port_range=(50000, 60000),
            mode="singularity_reuse",
            container_image=None,
            # TODO : This should be 10ms
            poll_period=100):
        """
        Parameters
        ----------
        task_q_url : str
             0mq url of the interchange's task channel that this manager
             connects to. Default: "tcp://127.0.0.1:50097"

        result_q_url : str
             0mq url of the interchange's result channel that this manager
             connects to. Default: "tcp://127.0.0.1:50098"

        max_queue_size : int
             Base cap on pending tasks; the effective cap used at runtime is
             max_queue_size + worker_count. Default: 10

        uid : str
             string unique identifier; used as the 0mq socket identity for
             both channels, so it must be set (None will raise on .encode())

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        max_workers : int
             caps the maximum number of workers that can be launched.
             default: infinity

        heartbeat_threshold : int
             Number of seconds since the last message from the interchange after which the
             interchange is assumed to be un-available, and the manager initiates shutdown.
             Default: 120s

        heartbeat_period : int
             Number of seconds after which a heartbeat message is sent to the interchange.
             Default: 30s

        logdir : str
             Directory path handed to spawned workers for their logs. Default: None

        debug : bool
             Enables debug logging in spawned workers when True. Default: False

        internal_worker_port_range : tuple(int, int)
             Port range from which the port(s) for the workers to connect to the manager is picked.
             Default: (50000,60000)

        mode : str
             Pick between 3 supported modes for the worker:
              1. no_container : Worker launched without containers
              2. singularity_reuse : Worker launched inside a singularity container that will be reused
              3. singularity_single_use : Each worker and task runs inside a new container instance.

        container_image : str
             Path or identifier for the container to be used. Default: None

        poll_period : int
             Timeout period used by the manager in milliseconds. Default: 100
             (the TODO in the signature notes this should eventually be 10ms)
        """

        logger.info("Manager started")

        # DEALER socket towards the interchange for receiving tasks; the uid
        # is the socket identity so the interchange can route to this manager.
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        # Linger is set to 0, so that the manager can exit even when there might be
        # messages in the pipe
        self.task_incoming.setsockopt(zmq.LINGER, 0)
        self.task_incoming.connect(task_q_url)

        self.logdir = logdir
        self.debug = debug
        # Second DEALER socket, mirroring task_incoming, for streaming results
        # back to the interchange.
        self.result_outgoing = self.context.socket(zmq.DEALER)
        self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        self.result_outgoing.setsockopt(zmq.LINGER, 0)
        self.result_outgoing.connect(result_q_url)
        logger.info("Manager connected")

        self.uid = uid

        self.mode = mode
        self.container_image = container_image
        cores_on_node = multiprocessing.cpu_count()
        self.max_workers = max_workers
        # Worker count is bounded by both max_workers and the per-core budget.
        self.worker_count = min(max_workers,
                                math.floor(cores_on_node / cores_per_worker))
        self.worker_map = WorkerMap(self.worker_count)
        logger.info("Manager will spawn {} workers".format(self.worker_count))

        self.internal_worker_port_range = internal_worker_port_range

        # ROUTER socket on which workers connect back; HWM set to 0 (unlimited)
        # so messages to workers are never silently dropped.
        self.funcx_task_socket = self.context.socket(zmq.ROUTER)
        self.funcx_task_socket.set_hwm(0)
        self.address = '127.0.0.1'
        self.worker_port = self.funcx_task_socket.bind_to_random_port(
            "tcp://*",
            min_port=self.internal_worker_port_range[0],
            max_port=self.internal_worker_port_range[1])

        logger.info(
            "Manager listening on {} port for incoming worker connections".
            format(self.worker_port))

        # Pending tasks keyed by task type; 'RAW' is the default type.
        self.task_queues = {'RAW': queue.Queue()}

        # Results posted here by the task-pull loop; drained by push_results().
        self.pending_result_queue = multiprocessing.Queue()

        self.max_queue_size = max_queue_size + self.worker_count
        self.tasks_per_round = 1

        self.heartbeat_period = heartbeat_period
        self.heartbeat_threshold = heartbeat_threshold
        self.poll_period = poll_period

    def create_reg_message(self):
        """ Build the JSON-encoded registration message that identifies this
        manager (worker pool) to the interchange.

        Returns
        -------
        bytes
            UTF-8 encoded JSON payload with version and host info.
        """
        registration = {
            'parsl_v': PARSL_VERSION,
            'python_v': "{}.{}.{}".format(sys.version_info.major,
                                          sys.version_info.minor,
                                          sys.version_info.micro),
            'os': platform.system(),
            'hname': platform.node(),
            'dir': os.getcwd(),
        }
        return json.dumps(registration).encode('utf-8')

    def heartbeat(self):
        """ Notify the interchange that this manager is alive.

        Sends the 4-byte little-endian HEARTBEAT_CODE over the task channel.
        """
        beat = HEARTBEAT_CODE.to_bytes(4, "little")
        sent = self.task_incoming.send(beat)
        logger.debug("Return from heartbeat: {}".format(sent))

    def pull_tasks(self, kill_event):
        """ Pull tasks from the incoming tasks 0mq pipe onto the internal
        pending task queue


        While :
            receive results and task requests from the workers
            receive tasks/heartbeats from the Interchange
            match tasks to workers
            if task doesn't have appropriate worker type:
                 launch worker of type.. with LRU or some sort of caching strategy.
            if workers >> tasks:
                 advertise available capacity

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """
        logger.info("[TASK PULL THREAD] starting")
        poller = zmq.Poller()
        poller.register(self.task_incoming, zmq.POLLIN)
        poller.register(self.funcx_task_socket, zmq.POLLIN)

        # Identify ourselves to the interchange before any task traffic flows
        msg = self.create_reg_message()
        logger.debug("Sending registration message: {}".format(msg))
        self.task_incoming.send(msg)
        last_beat = time.time()
        last_interchange_contact = time.time()
        task_recv_counter = 0
        task_done_counter = 0

        poll_timer = self.poll_period

        while not kill_event.is_set():
            # Disabling the check on ready_worker_queue disables batching
            logger.debug("[TASK_PULL_THREAD] Loop start")
            # Backlog = tasks received from the interchange minus results
            # returned; used below for flow control
            pending_task_count = task_recv_counter - task_done_counter
            ready_worker_count = self.worker_map.ready_worker_count()
            logger.debug(
                "[TASK_PULL_THREAD] pending_task_count: {} Ready_worker_count: {}"
                .format(pending_task_count, ready_worker_count))

            if time.time() > last_beat + self.heartbeat_period:
                self.heartbeat()
                last_beat = time.time()

            # Request only as many tasks as we have idle workers, and only
            # while the internal backlog stays under max_queue_size
            if pending_task_count < self.max_queue_size and ready_worker_count > 0:
                logger.debug("[TASK_PULL_THREAD] Requesting tasks: {}".format(
                    ready_worker_count))
                msg = ((ready_worker_count).to_bytes(4, "little"))
                self.task_incoming.send(msg)

            # Receive results from the workers, if any
            socks = dict(poller.poll(timeout=poll_timer))
            if self.funcx_task_socket in socks and socks[
                    self.funcx_task_socket] == zmq.POLLIN:
                try:
                    # Messages from workers are [worker_id, type, payload]
                    w_id, m_type, message = self.funcx_task_socket.recv_multipart(
                    )
                    if m_type == b'REGISTER':
                        reg_info = pickle.loads(message)
                        logger.info(
                            "Registration received from worker:{} {}".format(
                                w_id, reg_info))

                        # Increment worker_type count by 1
                        self.worker_map.register_worker(
                            w_id, reg_info['worker_type'])
                        # TODO : HERE

                    elif m_type == b'TASK_RET':
                        logger.info(
                            "Result received from worker:{}".format(w_id))
                        logger.debug(
                            "[TASK_PULL_THREAD] Got result: {}".format(
                                message))
                        self.pending_result_queue.put(message)
                        # The worker is idle again and may take a new task
                        self.worker_map.put_worker(w_id)
                        task_done_counter += 1

                        # UNCOMMENT to kill workers.
                        # self.kill_worker(1)

                    elif m_type == b'WRKR_DIE':
                        logger.debug(
                            "[KILL] Scrubbing the worker from the map!")
                        self.worker_map.scrub_worker(w_id)

                except Exception as e:
                    logger.warning(
                        "[TASK_PULL_THREAD] FUNCX : caught {}".format(e))

            # Receive task batches from Interchange and forward to workers
            if self.task_incoming in socks and socks[
                    self.task_incoming] == zmq.POLLIN:
                # Reset the poll backoff: traffic is flowing again
                poll_timer = 0
                _, pkl_msg = self.task_incoming.recv_multipart()
                tasks = pickle.loads(pkl_msg)
                last_interchange_contact = time.time()

                if tasks == 'STOP':
                    logger.critical("[TASK_PULL_THREAD] Received stop request")
                    kill_event.set()
                    break

                elif tasks == HEARTBEAT_CODE:
                    logger.debug("Got heartbeat from interchange")

                else:
                    # TODO : Update this to unpack and forward tasks to the appropriate
                    # workers.
                    task_recv_counter += len(tasks)
                    logger.debug(
                        "[TASK_PULL_THREAD] Got tasks: {} of {}".format(
                            [t['task_id'] for t in tasks], task_recv_counter))

                    for task in tasks:
                        # Queue per task_type; the dispatch loop below pairs
                        # queued tasks with registered workers of that type.
                        # Set default type to raw
                        task_type = task.get('task_type', 'RAW')
                        if task_type not in self.task_queues:
                            self.task_queues[task_type] = queue.Queue()
                        self.task_queues[task_type].put(task)
                        logger.debug(
                            "Task {} pushed to a task queue".format(task))

            else:
                logger.debug("[TASK_PULL_THREAD] No incoming tasks")
                # Limit poll duration to heartbeat_period
                # heartbeat_period is in s vs poll_timer in ms
                if not poll_timer:
                    poll_timer = self.poll_period
                # Exponential backoff while idle, capped at heartbeat_period
                poll_timer = min(self.heartbeat_period * 1000, poll_timer * 2)

                # Only check if no messages were received.
                if time.time(
                ) > last_interchange_contact + self.heartbeat_threshold:
                    logger.critical(
                        "[TASK_PULL_THREAD] Missing contact with interchange beyond heartbeat_threshold"
                    )
                    kill_event.set()
                    logger.critical("[TASK_PULL_THREAD] Exiting")
                    break

            # Dispatch: hand queued tasks to idle workers of the matching type
            current_worker_map = self.worker_map.get_worker_counts()
            for task_type in current_worker_map:
                if task_type == 'slots':
                    continue

                else:

                    # TODO: TYLER -- in here, make KILL actually kill.

                    available_workers = current_worker_map[task_type]
                    logger.debug("Available workers of type {}: {}".format(
                        task_type, available_workers))
                    for i in range(available_workers):
                        if task_type in self.task_queues and not self.task_queues[
                                task_type].empty():
                            task = self.task_queues[task_type].get()
                            worker_id = self.worker_map.get_worker(task_type)
                            logger.info("Sending task {} to {}".format(
                                task['task_id'], worker_id))
                            to_send = [
                                worker_id,
                                pickle.dumps(task['task_id']), task['buffer']
                            ]
                            self.funcx_task_socket.send_multipart(to_send)
                            logger.debug("Sending done")

    def push_results(self, kill_event, max_result_batch_size=1):
        """ Listens on the pending_result_queue and sends out results via 0mq

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """

        logger.debug("[RESULT_PUSH_THREAD] Starting thread")

        # push_poll_period must be atleast 10 ms; poll_period is in ms,
        # queue timeouts are in seconds
        push_poll_period = max(10, self.poll_period) / 1000
        logger.debug("[RESULT_PUSH_THREAD] push poll period: {}".format(
            push_poll_period))

        last_beat = time.time()
        batch = []

        while not kill_event.is_set():
            # Block briefly for the next result; an empty poll is routine
            try:
                batch.append(
                    self.pending_result_queue.get(block=True,
                                                  timeout=push_poll_period))
            except queue.Empty:
                pass
            except Exception as e:
                logger.exception(
                    "[RESULT_PUSH_THREAD] Got an exception: {}".format(e))

            # Flush when the batch is large enough or the timer has expired
            batch_full = len(batch) >= self.max_queue_size
            timer_expired = time.time() > last_beat + push_poll_period
            if batch_full or timer_expired:
                last_beat = time.time()
                if batch:
                    self.result_outgoing.send_multipart(batch)
                    batch = []

        logger.critical("[RESULT_PUSH_THREAD] Exiting")

    def launch_worker(self,
                      worker_id=None,
                      mode='no_container',
                      container_uri=None,
                      walltime=1):
        """ Launch the appropriate worker

        Parameters
        ----------
        worker_id : str
           Worker identifier string. When omitted, a random id is generated
           per call (a `str(random.random())` default argument would be
           evaluated once at definition time and shared by every call).
        mode : str
           Valid options are no_container, singularity
        container_uri : str
           Container image to run when mode == 'singularity'
        walltime : int
           Walltime in seconds before we check status

        Returns
        -------
        subprocess.Popen or None
           Handle of the launched worker process, or None if the launch
           attempt raised.
        """
        print("LAUNCH_WORKER is only partially baked")

        # TODO : This should assign some meaningful worker_id rather than random
        if worker_id is None:
            worker_id = str(random.random())

        debug = ' --debug' if self.debug else ''
        worker_id_opt = ' --worker_id {}'.format(worker_id)

        cmd = (f'funcx-worker {debug}{worker_id_opt} '
               f'-a {self.address} '
               f'-p {self.worker_port} '
               f'--logdir={self.logdir}/{self.uid} ')

        print("Command string : ", cmd)
        if mode == 'no_container':
            modded_cmd = cmd
        elif mode == 'singularity':
            # Fill the named placeholders with keyword arguments; positional
            # formatting here would raise at runtime and never fill {cmd}.
            modded_cmd = 'singularity run --writable {container_uri} {cmd}'.format(
                container_uri=container_uri, cmd=cmd)
        else:
            raise NameError("Invalid container launch mode.")

        proc = None
        try:
            # Launch the (possibly container-wrapped) command, not the raw
            # cmd, so the singularity mode is actually honored.
            # NOTE(review): shell=True with interpolated values is fine for
            # trusted config, but would be unsafe with untrusted input.
            proc = subprocess.Popen(modded_cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    shell=True)
            print("Launched proc")

        except Exception as e:
            print(
                "TODO : Got an error in worker launch, got error {}".format(e))

        return proc

    def kill_worker(self, worker_type):
        """ Kill a worker of a given worker_type.

        Adds a kill message to the task_type queue; the dispatch loop will
        deliver it to the next idle worker of that type.

        Assumption : All workers of the same type are uniform, and therefore
        don't discriminate when killing.
        """
        logger.debug("[KILL] Appending KILL message to worker queue")
        kill_message = {"task_id": "KILL", "buffer": "KILL"}
        self.task_queues[worker_type].put(kill_message)

    def start(self):
        """ Run the manager: launch workers, push results, pull tasks.

        * while True:
            Receive tasks and start appropriate workers
            Push tasks to available workers
            Forward results
        """
        # Per-type bookkeeping, populated lazily as tasks/workers appear
        self.task_queues = {}  # k-v: task_type - task_q (PriorityQueue)
        # k-v: worker_id - capacity (integer... should only ever be 0 or 1)
        self.worker_capacities = {}
        # TODO: Switch ^^^ to FIFO task queue.
        self.task_to_worker_sets = {}  # k-v: task_type - workers (set)

        # Keep track of workers to whom we've sent kill messages
        self.dead_worker_set = set()

        self.workers = [self.launch_worker(worker_id=5)]
        logger.debug("Initial workers launched")

        # Results are forwarded from a background thread; task pulling runs
        # in the current thread until the kill event fires.
        self._kill_event = threading.Event()
        pusher = threading.Thread(target=self.push_results,
                                  args=(self._kill_event, ))
        self._result_pusher_thread = pusher
        pusher.start()

        self.pull_tasks(self._kill_event)
        logger.info("Waiting")

    #----------------------------------------------------------------------------------------------
    # Deprecated
    def _start(self):
        """ Start the worker processes.

        Deprecated in favor of start(); retained for reference.

        Launches self.worker_count workers according to self.mode, waits for
        each one to register over a REP socket, runs the task-puller and
        result-pusher threads, then tears everything down on the kill event.

        TODO: Move task receiving to a thread
        """
        start = time.time()
        self._kill_event = threading.Event()

        # worker_id -> process handle; values are multiprocessing.Process for
        # no_container mode and subprocess.Popen for the singularity modes
        self.procs = {}

        # @Tyler, we should do a `which funcx_worker.py` [note: not entry point, this must be a script]
        # copy that file over the directory '.' and then have the container run with pwd visible
        # as an initial cut, while we resolve possible issues.
        orig_location = os.getcwd()

        if not os.path.isdir("NAMESPACE"):
            os.mkdir("NAMESPACE")

        context = zmq.Context()
        registration_socket = context.socket(zmq.REP)
        self.reg_port = registration_socket.bind_to_random_port("tcp://*",
                                                                min_port=50001,
                                                                max_port=55000)

        for worker_id in range(self.worker_count):

            if self.mode.startswith("singularity"):
                try:
                    os.mkdir("NAMESPACE/{}".format(worker_id))
                except Exception:
                    pass  # Assuming the directory already exists.

            if self.mode == "no_container":
                p = multiprocessing.Process(
                    target=funcx_worker,
                    args=(
                        worker_id,
                        self.uid,
                        "tcp://localhost:{}".format(self.internal_worker_port),
                        "tcp://localhost:{}".format(self.reg_port),
                    ),
                    # DEBUG YADU. MUST SET BACK TO False,
                    kwargs={
                        'no_reuse': False,
                        'debug': self.debug,
                        'logdir': self.logdir
                    })

                p.start()
                self.procs[worker_id] = p

            elif self.mode == "singularity_reuse":

                os.chdir("NAMESPACE/{}".format(worker_id))
                # @Tyler, FuncX worker path needs to be updated to not use the run command in the container.
                # We just want to invoke with "funcx_worker.py" which is found in the $PATH
                sys_cmd = (
                    "singularity run {singularity_img} /usr/local/bin/funcx_worker.py --worker_id {worker_id} "
                    "--pool_id {pool_id} --task_url {task_url} --reg_url {reg_url} "
                    "--logdir {logdir} ")

                sys_cmd = sys_cmd.format(singularity_img=self.container_image,
                                         worker_id=worker_id,
                                         pool_id=self.uid,
                                         task_url="tcp://localhost:{}".format(
                                             self.internal_worker_port),
                                         reg_url="tcp://localhost:{}".format(
                                             self.reg_port),
                                         logdir=self.logdir)

                logger.debug(
                    "Singularity reuse launch cmd: {}".format(sys_cmd))
                proc = subprocess.Popen(sys_cmd, shell=True)
                self.procs[worker_id] = proc

                # FuncX worker to accept new --no_reuse flag that breaks the loop after 1 task.
                os.chdir(orig_location)

            elif self.mode == "singularity_single_use":
                os.chdir("NAMESPACE/{}".format(worker_id))

                if self.mode.startswith("singularity"):
                    logger.info("New subprocess loop!")
                    sys_cmd = (
                        "singularity run {singularity_img} /usr/local/bin/funcx_worker.py --no_reuse --worker_id {worker_id} "
                        "--pool_id {pool_id} --task_url {task_url} --reg_url {reg_url} "
                        "--logdir {logdir} ")
                    sys_cmd = sys_cmd.format(
                        singularity_img=self.container_image,
                        worker_id=worker_id,
                        pool_id=self.uid,
                        task_url="tcp://localhost:{}".format(
                            self.internal_worker_port),
                        reg_url="tcp://localhost:{}".format(self.reg_port),
                        logdir=self.logdir)

                    # Relaunch in a shell loop so each single-use worker is
                    # replaced after it exits
                    bash_cmd = """ while :
                                   do
                                      {}
                                   done """.format(sys_cmd)
                    logger.debug(
                        "Singularity NO-reuse launch cmd: {}".format(bash_cmd))
                    proc = subprocess.Popen(bash_cmd, shell=True)
                    self.procs[worker_id] = proc
                    os.chdir(orig_location)

        # Block until every launched worker has registered
        for worker_id in range(self.worker_count):
            msg = registration_socket.recv_pyobj()
            logger.info(
                "Received registration message from worker: {}".format(msg))
            registration_socket.send_pyobj("ACK")

        logger.debug("Manager synced with workers")

        self._task_puller_thread = threading.Thread(target=self.pull_tasks,
                                                    args=(self._kill_event, ))
        self._result_pusher_thread = threading.Thread(
            target=self.push_results, args=(self._kill_event, ))
        self._task_puller_thread.start()
        self._result_pusher_thread.start()

        logger.info("Loop start")

        # TODO : Add mechanism in this loop to stop the worker pool
        # This might need a multiprocessing event to signal back.
        self._kill_event.wait()
        logger.critical(
            "[MAIN] Received kill event, terminating worker processes")

        self._task_puller_thread.join()
        self._result_pusher_thread.join()
        for proc_id in self.procs:
            self.procs[proc_id].terminate()

            # Popen handles and multiprocessing.Process expose different
            # APIs: poll()/wait() vs is_alive()/join(). Checking the actual
            # type (the old `type(x) == "subprocess.Popen"` string compare
            # was always False and crashed on Popen handles).
            if isinstance(self.procs[proc_id], subprocess.Popen):
                # poll() returns None while the process is still running
                is_alive = self.procs[proc_id].poll() is None

                logger.critical("Terminating worker {}:{}".format(
                    self.procs[proc_id], is_alive))
                self.procs[proc_id].wait()
            else:
                logger.critical("Terminating worker {}:{}".format(
                    self.procs[proc_id], self.procs[proc_id].is_alive()))
                self.procs[proc_id].join()

            logger.debug("Worker:{} joined successfully".format(
                self.procs[proc_id]))

        self.task_incoming.close()
        self.result_outgoing.close()
        self.context.term()
        delta = time.time() - start
        logger.info("FuncX Manager ran for {} seconds".format(delta))
        return