class CentralScheduler(object):

    def __init__(self, endpoints, strategy='round-robin',
                 runtime_predictor='rolling-average', last_n=3, train_every=1,
                 log_level='INFO', import_model_file=None,
                 transfer_model_file=None, sync_level='exists',
                 max_backups=0, backup_delay_threshold=2.0,
                 *args, **kwargs):
        self._fxc = FuncXClient(*args, **kwargs)

        # Initialize a transfer client
        self._transfer_manager = TransferManager(endpoints=endpoints,
                                                 sync_level=sync_level,
                                                 log_level=log_level)

        # Info about FuncX endpoints we can execute on
        self._endpoints = endpoints
        self._dead_endpoints = set()
        self.last_result_time = defaultdict(float)
        self.temperature = defaultdict(lambda: 'WARM')
        self._imports = defaultdict(list)
        self._imports_required = defaultdict(list)

        # Track which endpoints a function can't run on
        self._blocked = defaultdict(set)

        # Track pending tasks.
        # We will provide the client our own task ids, since we may submit the
        # same task multiple times to the FuncX service, and sometimes we may
        # wait to submit a task to FuncX (e.g., wait for a data transfer).
        self._task_id_translation = {}
        self._pending = {}
        self._pending_by_endpoint = defaultdict(set)
        self._task_info = {}
        # List of endpoints a (virtual) task was scheduled to
        self._endpoints_sent_to = defaultdict(list)
        self.max_backups = max_backups
        self.backup_delay_threshold = backup_delay_threshold
        self._latest_status = {}
        self._last_task_ETA = defaultdict(float)

        # Maximum ETA, if any, of a task which we allow to be scheduled on an
        # endpoint. This is to prevent backfill tasks from running longer than
        # the estimated time at which a pending data transfer will finish.
        self._transfer_ETAs = defaultdict(dict)
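        # (Maps each endpoint to {transfer_num: ETA timestamp}; entries are
        # added when a Globus transfer is started and removed once the
        # transfer completes.)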

        # Estimated error in the pending-task time of an endpoint.
        # Updated every time a task result is received from an endpoint.
        self._queue_error = defaultdict(float)

        # Set logging levels
        logger.setLevel(log_level)
        self.execution_log = []

        # Initialize serializer
        self.fx_serializer = FuncXSerializer()
        self.fx_serializer.use_custom('03\n', 'code')

        # Initialize runtime predictor
        self.runtime = init_runtime_predictor(runtime_predictor,
                                              endpoints=endpoints,
                                              last_n=last_n,
                                              train_every=train_every)
        logger.info(f"Runtime predictor using strategy {self.runtime}")

        # Initialize transfer-time predictor
        self.transfer_time = TransferPredictor(endpoints=endpoints,
                                               train_every=train_every,
                                               state_file=transfer_model_file)

        # Initialize import-time predictor
        self.import_predictor = ImportPredictor(endpoints=endpoints,
                                                state_file=import_model_file)

        # Initialize scheduling strategy
        self.strategy = init_strategy(strategy, endpoints=endpoints,
                                      runtime_predictor=self.runtime,
                                      queue_predictor=self.queue_delay,
                                      cold_start_predictor=self.cold_start,
                                      transfer_predictor=self.transfer_time)
        logger.info(f"Scheduler using strategy {self.strategy}")

        # Start thread to check on endpoints regularly
        self._endpoint_watchdog = Thread(target=self._check_endpoints)
        self._endpoint_watchdog.start()

        # Start thread to monitor tasks and send tasks to FuncX service
        self._scheduled_tasks = Queue()
        self._task_watchdog_sleep = 0.15
        self._task_watchdog = Thread(target=self._monitor_tasks)
        self._task_watchdog.start()

    def block(self, func, endpoint):
        # Prevent a function from being scheduled on an endpoint, unless
        # doing so would leave it with no endpoints to run on
        if endpoint not in self._endpoints:
            logger.error('Cannot block unknown endpoint {}'.format(endpoint))
            return {
                'status': 'Failed',
                'reason': 'Unknown endpoint {}'.format(endpoint)
            }
        elif len(self._blocked[func]) == len(self._endpoints) - 1:
            logger.error(
                'Cannot block last remaining endpoint {}'.format(endpoint))
            return {
                'status': 'Failed',
                'reason': 'Cannot block all endpoints for {}'.format(func)
            }
        else:
            logger.info('Blocking endpoint {} for function {}'.format(
                endpoint_name(endpoint), func))
            self._blocked[func].add(endpoint)
            return {'status': 'Success'}

    def register_imports(self, func, imports):
        logger.info('Registered function {} with imports {}'.format(
            func, imports))
        self._imports_required[func] = imports

    def batch_submit(self, tasks, headers):
        # TODO: smarter scheduling for batch submissions

        task_ids = []
        endpoints = []

        for func, payload in tasks:
            _, ser_kwargs = self.fx_serializer.unpack_buffers(payload)
            kwargs = self.fx_serializer.deserialize(ser_kwargs)
            files = kwargs['_globus_files']

            task_id, endpoint = self._schedule_task(func=func,
                                                    payload=payload,
                                                    headers=headers,
                                                    files=files)
            task_ids.append(task_id)
            endpoints.append(endpoint)

        return task_ids, endpoints
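
    # Each entry in `tasks` above is a (function_id, payload) pair, where the
    # payload packs the serialized args and kwargs (the kwargs carrying the
    # '_globus_files' entry read in batch_submit). A minimal sketch of how a
    # client might build one (illustrative; the variable names and argument
    # values are assumptions, not from this module):
    #
    #     ser = FuncXSerializer()
    #     payload = ser.pack_buffers([
    #         ser.serialize(args),
    #         ser.serialize({'_globus_files': files, **kwargs}),
    #     ])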

    def _schedule_task(self, func, payload, headers, files, task_id=None):
        # If this is the first time scheduling this task_id
        # (i.e., a non-backup task), record the necessary metadata
        if task_id is None:
            # Create (fake) task id to return to client
            task_id = str(uuid.uuid4())

            # Store task information
            self._task_id_translation[task_id] = set()

            # Information required to schedule the task, now and in the future
            info = {
                'function_id': func,
                'payload': payload,
                'headers': headers,
                'files': files,
                'time_requested': time.time()
            }
            self._task_info[task_id] = info

        # TODO: do not choose a dead endpoint (reliably)
        # exclude = self._blocked[func] | self._dead_endpoints | set(self._endpoints_sent_to[task_id])  # noqa
        if len(self._dead_endpoints) > 0:
            logger.warning('{} endpoints seem dead. Hope they still work!'
                           .format(len(self._dead_endpoints)))
        exclude = self._blocked[func] | set(self._endpoints_sent_to[task_id])
        choice = self.strategy.choose_endpoint(
            func, payload=payload, files=files, exclude=exclude,
            transfer_ETAs=self._transfer_ETAs)
        endpoint = choice['endpoint']
        logger.info('Choosing endpoint {} for func {}, task id {}'.format(
            endpoint_name(endpoint), func, task_id))
        choice['ETA'] = self.strategy.predict_ETA(func, endpoint, payload,
                                                  files=files)

        # Start Globus transfer of required files, if any
        if len(files) > 0:
            transfer_num = self._transfer_manager.transfer(files, endpoint,
                                                           task_id)
            if transfer_num is not None:
                transfer_ETA = time.time() + self.transfer_time(files,
                                                                endpoint)
                self._transfer_ETAs[endpoint][transfer_num] = transfer_ETA
        else:
            transfer_num = None

        # Record endpoint ETA for queue-delay prediction here,
        # since the task will be immediately scheduled
        self._last_task_ETA[endpoint] = choice['ETA']

        # If a cold endpoint is being started, mark it as no longer cold,
        # so that subsequent launch-time predictions are correct (i.e., 0)
        if self.temperature[endpoint] == 'COLD':
            self.temperature[endpoint] = 'WARMING'
            logger.info(
                'A cold endpoint {} was chosen; marked as warming.'.format(
                    endpoint_name(endpoint)))

        # Schedule task for sending to FuncX
        self._endpoints_sent_to[task_id].append(endpoint)
        self._scheduled_tasks.put((task_id, endpoint, transfer_num))

        return task_id, endpoint

    def translate_task_id(self, task_id):
        return self._task_id_translation[task_id]

    def log_status(self, real_task_id, data):
        if real_task_id not in self._pending:
            logger.warning('Ignoring unknown task id {}'.format(real_task_id))
            return

        task_id = self._pending[real_task_id]['task_id']
        func = self._pending[real_task_id]['function_id']
        endpoint = self._pending[real_task_id]['endpoint_id']

        # Don't overwrite latest status if it is a result/exception
        if task_id not in self._latest_status or \
                self._latest_status[task_id].get('status') == 'PENDING':
            self._latest_status[task_id] = data

        if 'result' in data:
            result = self.fx_serializer.deserialize(data['result'])
            runtime = result['runtime']
            name = endpoint_name(endpoint)
            logger.info('Got result from {} for task {} with time {}'.format(
                name, real_task_id, runtime))

            self.runtime.update(self._pending[real_task_id], runtime)
            self._pending[real_task_id]['runtime'] = runtime
            self._record_completed(real_task_id)
            self.last_result_time[endpoint] = time.time()
            self._imports[endpoint] = result['imports']

        elif 'exception' in data:
            exception = self.fx_serializer.deserialize(data['exception'])
            try:
                exception.reraise()
            except Exception as e:
                logger.error('Got exception on task {}: {}'.format(
                    real_task_id, e))
                exc_type, _, _ = sys.exc_info()
                if exc_type in BLOCK_ERRORS:
                    self.block(func, endpoint)

            self._record_completed(real_task_id)
            self.last_result_time[endpoint] = time.time()

        elif 'status' in data and data['status'] == 'PENDING':
            pass

        else:
            logger.error('Unexpected status message: {}'.format(data))

    def get_status(self, task_id):
        if task_id not in self._task_id_translation:
            logger.warning('Unknown client task id {}'.format(task_id))
            return {
                'status': 'Failed',
                'reason': 'Unknown task id {}'.format(task_id)
            }
        elif len(self._task_id_translation[task_id]) == 0:
            return {'status': 'PENDING'}  # Task has not been scheduled yet
        elif task_id not in self._latest_status:
            return {'status': 'PENDING'}  # Status has not been queried yet
        else:
            return self._latest_status[task_id]

    def queue_delay(self, endpoint):
        # The queue delay is the ETA of the most recently sent task on the
        # endpoint, plus the estimated error in that ETA prediction.
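        # Worked example (illustrative numbers, not from the source): if the
        # last task sent to this endpoint has ETA = now + 40s, and recent
        # predictions have run ~2s late (_queue_error = +2.0), the delay is
        # max(now + 42, now) = now + 42.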
        # Note that if there are no pending tasks on the endpoint, there is
        # no queue delay. This is implicit since, in this case, both summands
        # will be 0.
        delay = self._last_task_ETA[endpoint] + self._queue_error[endpoint]
        return max(delay, time.time())

    def _record_completed(self, real_task_id):
        info = self._pending[real_task_id]
        endpoint = info['endpoint_id']

        # If this is the last pending task on this endpoint, reset ETA offset
        if len(self._pending_by_endpoint[endpoint]) == 1:
            self._last_task_ETA[endpoint] = 0.0
            self._queue_error[endpoint] = 0.0
        else:
            prediction_error = time.time() - info['ETA']
            self._queue_error[endpoint] = prediction_error
            # print(colored(f'Prediction error {prediction_error}', 'red'))

        info['ATA'] = time.time()
        del info['headers']
        self.execution_log.append(info)

        logger.info(
            'Task exec time: expected = {:.3f}, actual = {:.3f}'.format(
                info['ETA'] - info['time_sent'],
                time.time() - info['time_sent']))
        # logger.info(f'ETA_offset = {self._queue_error[endpoint]:.3f}')

        # Stop tracking this task
        del self._pending[real_task_id]
        self._pending_by_endpoint[endpoint].remove(real_task_id)
        if info['task_id'] in self._task_info:
            del self._task_info[info['task_id']]

    def cold_start(self, endpoint, func):
        # If the endpoint is warm, there is no launch time
        if self.temperature[endpoint] != 'COLD':
            launch_time = 0.0
        # Otherwise, use the launch time from the endpoint config
        elif 'launch_time' in self._endpoints[endpoint]:
            launch_time = self._endpoints[endpoint]['launch_time']
        else:
            logger.warning(
                'Endpoint {} should always be warm, but is cold'.format(
                    endpoint_name(endpoint)))
            launch_time = 0.0

        # Time to import dependencies
        import_time = 0.0
        for pkg in self._imports_required[func]:
            if pkg not in self._imports[endpoint]:
                logger.debug(
                    'Cold-start has import time for pkg {} on {}'.format(
                        pkg, endpoint_name(endpoint)))
                import_time += self.import_predictor(pkg, endpoint)

        return launch_time + import_time
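
    # Example (illustrative numbers, not from the source): a COLD endpoint
    # with a configured launch_time of 30s that is missing one required
    # import estimated at 2s by the ImportPredictor gives a cold start of
    # 32s; a WARM endpoint with all imports already loaded gives 0.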

    def _monitor_tasks(self):
        logger.info('Starting task-watchdog thread')

        scheduled = {}

        while True:
            time.sleep(self._task_watchdog_sleep)

            # Get newly scheduled tasks
            while True:
                try:
                    task_id, end, num = self._scheduled_tasks.get_nowait()
                    if task_id not in self._task_info:
                        logger.warning(
                            'Task id {} scheduled but no info found'.format(
                                task_id))
                        continue
                    info = self._task_info[task_id]
                    scheduled[task_id] = dict(info)  # Create new copy of info
                    scheduled[task_id]['task_id'] = task_id
                    scheduled[task_id]['endpoint_id'] = end
                    scheduled[task_id]['transfer_num'] = num
                except Empty:
                    break

            # Filter out all tasks whose data transfer has not been completed
            ready_to_send = set()
            for task_id, info in scheduled.items():
                transfer_num = info['transfer_num']
                if transfer_num is None:
                    ready_to_send.add(task_id)
                    info['transfer_time'] = 0.0
                elif self._transfer_manager.is_complete(transfer_num):
                    ready_to_send.add(task_id)
                    del self._transfer_ETAs[info['endpoint_id']][transfer_num]
                    info['transfer_time'] = \
                        self._transfer_manager.get_transfer_time(transfer_num)
                else:
                    # This task cannot be scheduled yet
                    continue

            if len(ready_to_send) == 0:
                logger.debug('No new tasks to send. Task watchdog sleeping...')
                continue

            # TODO: different clients send different headers; change this
            # eventually
            headers = list(scheduled.values())[0]['headers']

            logger.info('Scheduling a batch of {} tasks'.format(
                len(ready_to_send)))

            # Submit all ready tasks to FuncX
            data = {'tasks': []}
            for task_id in ready_to_send:
                info = scheduled[task_id]
                submit_info = (info['function_id'], info['endpoint_id'],
                               info['payload'])
                data['tasks'].append(submit_info)

            res_str = requests.post(f'{FUNCX_API}/submit', headers=headers,
                                    data=json.dumps(data))
            try:
                res = res_str.json()
            except ValueError:
                logger.error(f'Could not parse JSON from {res_str.text}')
                continue
            if res['status'] != 'Success':
                logger.error(
                    'Could not send tasks to FuncX. Got response: {}'.format(
                        res))
                continue

            # Update task info with submission info
            for task_id, real_task_id in zip(ready_to_send,
                                             res['task_uuids']):
                info = scheduled[task_id]

                # This ETA calculation does not take transfer time into
                # account since, at this point, the transfer has already
                # completed.
                info['ETA'] = self.strategy.predict_ETA(
                    info['function_id'], info['endpoint_id'],
                    info['payload'])

                # Record whether this ETA prediction is "reliable". If it is
                # not (e.g., when we have not learned about this (func, ep)
                # pair), backup tasks will not be sent for this task if it is
                # delayed.
                info['is_ETA_reliable'] = self.runtime.has_learned(
                    info['function_id'], info['endpoint_id'])

                info['time_sent'] = time.time()
                endpoint = info['endpoint_id']

                self._task_id_translation[task_id].add(real_task_id)
                self._pending[real_task_id] = info
                self._pending_by_endpoint[endpoint].add(real_task_id)

                # Record endpoint ETA for queue-delay prediction
                self._last_task_ETA[endpoint] = info['ETA']

                logger.info(
                    'Sent task id {} to {} with real task id {}'.format(
                        task_id, endpoint_name(endpoint), real_task_id))

            # Stop tracking all newly sent tasks
            for task_id in ready_to_send:
                del scheduled[task_id]

    def _check_endpoints(self):
        logger.info('Starting endpoint-watchdog thread')

        while True:
            for end in self._endpoints.keys():
                statuses = self._fxc.get_endpoint_status(end)
                if len(statuses) == 0:
                    logger.warning(
                        'Endpoint {} does not have any statuses'.format(
                            endpoint_name(end)))
                else:
                    status = statuses[0]  # Most recent endpoint status

                    # Mark endpoint as dead/alive based on the heartbeat's
                    # age. Heartbeats are delayed when an endpoint is
                    # executing tasks, so take the last execution into
                    # account too.
                    age = time.time() - max(status['timestamp'],
                                            self.last_result_time[end])
                    is_dead = end in self._dead_endpoints
                    if not is_dead and age > HEARTBEAT_THRESHOLD:
                        self._dead_endpoints.add(end)
                        logger.warning(
                            'Endpoint {} seems to have died! '
                            'Last heartbeat was {:.2f} seconds ago.'.format(
                                endpoint_name(end), age))
                    elif is_dead and age <= HEARTBEAT_THRESHOLD:
                        self._dead_endpoints.remove(end)
                        logger.warning(
                            'Endpoint {} is back alive! '
                            'Last heartbeat was {:.2f} seconds ago.'.format(
                                endpoint_name(end), age))

                    # Mark endpoint as "cold" or "warm" depending on whether
                    # it has active managers (nodes) allocated to it
                    if self.temperature[end] == 'WARM' \
                            and status['active_managers'] == 0:
                        self.temperature[end] = 'COLD'
                        logger.info('Endpoint {} is cold!'.format(
                            endpoint_name(end)))
                    elif self.temperature[end] != 'WARM' \
                            and status['active_managers'] > 0:
                        self.temperature[end] = 'WARM'
                        logger.info('Endpoint {} is warm again!'.format(
                            endpoint_name(end)))

            # Send backup tasks if needed
            self._send_backups_if_needed()

            # Sleep before checking statuses again
            time.sleep(5)

    def _send_backups_if_needed(self):
        # Get all tasks which have not been completed yet and still have a
        # pending (real) task on a dead endpoint
        task_ids = {
            self._pending[real_task_id]['task_id']
            for endpoint in self._dead_endpoints
            for real_task_id in self._pending_by_endpoint[endpoint]
            if self._pending[real_task_id]['task_id'] in self._task_info
        }

        # Get all tasks for which we had ETA predictions but which have not
        # been completed even past their ETA
        for real_task_id, info in self._pending.items():
            # If the predicted ETA wasn't reliable, don't send backups
            if not info['is_ETA_reliable']:
                continue

            expected = info['ETA'] - info['time_sent']
            elapsed = time.time() - info['time_sent']
            if elapsed / expected > self.backup_delay_threshold:
                task_ids.add(info['task_id'])

        for task_id in task_ids:
            if len(self._endpoints_sent_to[task_id]) > self.max_backups:
                logger.debug(f'Skipping sending new backup task for {task_id}')
            else:
                logger.info(f'Sending new backup task for {task_id}')
                info = self._task_info[task_id]
                self._schedule_task(info['function_id'], info['payload'],
                                    info['headers'], info['files'], task_id)
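

# Example usage (a minimal sketch; the endpoint UUIDs are placeholders,
# `tasks` and `headers` are assumed to come from a client request, and
# FuncXClient credentials are assumed to already be configured):
#
#     endpoints = {
#         '<endpoint-uuid-1>': {'launch_time': 30.0},
#         '<endpoint-uuid-2>': {},
#     }
#     scheduler = CentralScheduler(endpoints, strategy='round-robin',
#                                  max_backups=1)
#     task_ids, chosen = scheduler.batch_submit(tasks, headers)
#     print(scheduler.get_status(task_ids[0]))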
""" res = requests.post('http://localhost:8000/deserialize', json=payload) print(res.status_code) return res.json() def serialize(payload): """Try to serialize some input and return the result. """ res = requests.post('http://localhost:8000/serialize', json=payload) print(res.status_code) return res.json() if __name__ == "__main__" : payload = {'name': 'bob'} print(f'Input: {payload}') x = serialize(payload) print(f'Serialized: {x}') # Trim off kwargs (part 2 of the buffer) fx_serializer = FuncXSerializer() res = fx_serializer.unpack_buffers(x) print(res) y = deserialize(res[0]) print(f'Deserialized: {y}') print('now break things') z = deserialize(res) print(z)