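"""Inference-serving components.

``InferenceWorker`` runs one trained-model replica: it pulls batches of
queries from Kafka, runs the model's ``predict()``, and pushes the
predictions back. ``Predictor`` is the single front end for an inference
job: it fans each batch of queries out to all registered workers, collects
their predictions, and combines them with a task-specific ensemble method.
"""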
import os
import time
import logging
import traceback
from collections import defaultdict
from typing import Any, Callable, List, Optional, Type

# NOTE: The project-internal names used below -- MetaStore, FileParamStore,
# RedisInferenceCache, KafkaInferenceCache, BaseModel, Proposal, Query,
# Prediction, load_model_class, superadmin_client, get_ensemble_method, the
# Invalid*Error types, and the PREDICT_BATCH_SIZE / LOOP_SLEEP_SECS /
# PREDICT_LOOP_SLEEP_SECS constants -- are assumed to be importable from the
# surrounding singa_auto package.

logger = logging.getLogger(__name__)


class InferenceWorker:

    def __init__(self, service_id, worker_id, meta_store=None, param_store=None):
        self._service_id = service_id
        self._worker_id = worker_id
        self._meta_store = meta_store or MetaStore()
        self._param_store = param_store or FileParamStore()
        self._redis_host = os.getenv('REDIS_HOST', 'singa_auto_redis')
        # Env vars are strings; cast ports to int so clients get a valid port
        self._redis_port = int(os.getenv('REDIS_PORT', 6379))
        self._kafka_host = os.getenv('KAFKA_HOST', 'singa_auto_kafka')
        self._kafka_port = int(os.getenv('KAFKA_PORT', 9092))
        self._batch_size = PREDICT_BATCH_SIZE
        self._redis_cache: Optional[RedisInferenceCache] = None
        self._inference_job_id = None
        self._model_inst: Optional[BaseModel] = None
        self._proposal: Optional[Proposal] = None
        self._store_params_id = None
        self._py_model_class: Optional[Type[BaseModel]] = None
        self._kafka_cache = KafkaInferenceCache()

    def start(self):
        self._pull_job_info()
        self._redis_cache = RedisInferenceCache(self._inference_job_id,
                                                self._redis_host,
                                                self._redis_port)
        logger.info(
            f'Starting worker for inference job "{self._inference_job_id}"...')
        self._notify_start()

        # Load trial's model instance
        self._model_inst = self._load_trial_model()

        # Serve forever: pop a batch of queries, predict, submit predictions
        while True:
            queries = self._fetch_queries()
            if len(queries) > 0:
                predictions = self._predict(queries)
                self._submit_predictions(predictions)
            else:
                time.sleep(LOOP_SLEEP_SECS)

    def stop(self):
        self._notify_stop()

        # Run model destroy
        try:
            if self._model_inst is not None:
                self._model_inst.destroy()
        except Exception:
            logger.error('Error destroying model:')
            logger.error(traceback.format_exc())

        # Run model class teardown
        try:
            if self._py_model_class is not None:
                self._py_model_class.teardown()
        except Exception:
            logger.error('Error tearing down model class:')
            logger.error(traceback.format_exc())

    def _pull_job_info(self):
        service_id = self._service_id

        logger.info('Reading job info from meta store...')
        with self._meta_store:
            worker = self._meta_store.get_inference_job_worker(service_id)
            if worker is None:
                raise InvalidWorkerError(
                    'No such worker "{}"'.format(service_id))

            inference_job = self._meta_store.get_inference_job(
                worker.inference_job_id)
            if inference_job is None:
                raise InvalidWorkerError(
                    'No such inference job with ID "{}"'.format(
                        worker.inference_job_id))

            if inference_job.model_id:
                # Serve directly from a model checkpoint (no trial involved)
                model = self._meta_store.get_model(inference_job.model_id)
                logger.info(f'Using checkpoint of the model "{model.name}"...')
                self._proposal = Proposal.from_jsonable({
                    "trial_no": 1,
                    "knobs": {}
                })
                self._store_params_id = model.checkpoint_id
            else:
                # Serve from a trained trial's saved parameters
                trial = self._meta_store.get_trial(worker.trial_id)
                if trial is None or trial.store_params_id is None:
                    # Must have model saved
                    raise InvalidTrialError(
                        'No saved trial with ID "{}"'.format(worker.trial_id))
                logger.info(f'Using trial "{trial.id}"...')

                model = self._meta_store.get_model(trial.model_id)
                if model is None:
                    raise InvalidTrialError(
                        'No such model with ID "{}"'.format(trial.model_id))
                logger.info(f'Using model "{model.name}"...')

                self._proposal = Proposal.from_jsonable(trial.proposal)
                self._store_params_id = trial.store_params_id

            self._inference_job_id = inference_job.id
            self._py_model_class = load_model_class(model.model_file_bytes,
                                                    model.model_class)

    def _load_trial_model(self):
        logger.info('Loading saved model parameters from store...')
        params = self._param_store.load(self._store_params_id)

        logger.info('Loading trial\'s trained model...')
        model_inst = self._py_model_class(**self._proposal.knobs)
        model_inst.load_parameters(params)

        return model_inst

    def _notify_start(self):
        superadmin_client().send_event('inference_job_worker_started',
                                       inference_job_id=self._inference_job_id)
        # Register this worker so the predictor can route queries to it
        self._redis_cache.add_worker(self._worker_id)

    def _fetch_queries(self) -> List[Query]:
        queries = self._kafka_cache.pop_queries_for_worker(
            self._worker_id, self._batch_size)
        return queries

    def _predict(self, queries: List[Query]) -> List[Prediction]:
        # Pass queries to model, set null predictions if it errors
        try:
            predictions = self._model_inst.predict([x.query for x in queries])
        except Exception:
            logger.error('Error while making predictions:')
            logger.error(traceback.format_exc())
            predictions = [None for x in range(len(queries))]

        # Transform predictions, adding associated worker & query ID
        predictions = [
            Prediction(x, query.id, self._worker_id)
            for (x, query) in zip(predictions, queries)
        ]

        return predictions

    def _submit_predictions(self, predictions: List[Prediction]):
        self._kafka_cache.add_predictions_for_worker(self._worker_id,
                                                     predictions)

    def _notify_stop(self):
        self._redis_cache.delete_worker(self._worker_id)
        superadmin_client().send_event('inference_job_worker_stopped',
                                       inference_job_id=self._inference_job_id)
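# --- Usage sketch ---
# A minimal sketch of how a worker process could be launched and shut down
# cleanly. This is illustrative only, not the project's actual entrypoint,
# and the service/worker IDs below are hypothetical placeholders (in a real
# deployment they are assigned when the service is created).
def _example_run_worker():
    import signal
    import sys

    worker = InferenceWorker(service_id='svc-1234', worker_id='worker-0')

    def _handle_sigterm(signum, frame):
        worker.stop()  # Deregisters from Redis & tears down the model
        sys.exit(0)

    # Ensure stop() runs when the container/service is terminated
    signal.signal(signal.SIGTERM, _handle_sigterm)
    worker.start()  # Blocks: poll Kafka, predict, submit, sleep, repeat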
class Predictor:

    def __init__(self, service_id, meta_store=None):
        self._service_id = service_id
        self._meta_store = meta_store or MetaStore()
        self._redis_host = os.getenv('REDIS_HOST', 'singa_auto_redis')
        # Env vars are strings; cast ports to int so clients get a valid port
        self._redis_port = int(os.getenv('REDIS_PORT', 6379))
        self._kafka_host = os.getenv('KAFKA_HOST', 'singa_auto_kafka')
        self._kafka_port = int(os.getenv('KAFKA_PORT', 9092))
        self._ensemble_method: Optional[Callable[[List[Any]], Any]] = None
        self._pull_job_info()
        self._redis_cache = RedisInferenceCache(self._inference_job_id,
                                                self._redis_host,
                                                self._redis_port)
        self._kafka_cache = KafkaInferenceCache()
        logger.info(
            f'Initialized predictor for inference job "{self._inference_job_id}"'
        )

    # Only a single thread should run this
    def start(self):
        self._notify_start()

    def predict(self, queries):
        worker_predictions_list = self._get_predictions_from_workers(queries)
        predictions = self._combine_worker_predictions(worker_predictions_list)
        return predictions

    # Only a single thread should run this
    def stop(self):
        self._notify_stop()

        # Clear caches for inference job
        try:
            self._redis_cache.clear_all()
        except Exception:
            logger.error('Error clearing inference cache:')
            logger.error(traceback.format_exc())

    def _pull_job_info(self):
        service_id = self._service_id

        logger.info('Reading job info from meta store...')
        with self._meta_store:
            inference_job = self._meta_store.get_inference_job_by_predictor(
                service_id)
            if inference_job is None:
                raise InvalidInferenceJobError(
                    'No inference job associated with predictor "{}"'.format(
                        service_id))

            if inference_job.train_job_id is None and inference_job.model_id is None:
                raise InvalidInferenceJobError(
                    'No train job or checkpoint found with inference ID "{}"'.
                    format(inference_job.id))

            if inference_job.train_job_id is not None:
                train_job = self._meta_store.get_train_job(
                    inference_job.train_job_id)
                if train_job is None:
                    raise InvalidInferenceJobError(
                        'No such train job with ID "{}"'.format(
                            inference_job.train_job_id))
                # Pick the ensemble method for the train job's task
                self._ensemble_method = get_ensemble_method(train_job.task)

            if inference_job.model_id is not None:
                # Checkpoint-only job: fall back to the default ensemble method
                self._ensemble_method = get_ensemble_method()

            self._inference_job_id = inference_job.id
            logger.info(f'Using ensemble method: {self._ensemble_method}...')

    def _notify_start(self):
        superadmin_client().send_event('predictor_started',
                                       inference_job_id=self._inference_job_id)

    def _get_predictions_from_workers(
            self, queries: List[Any]) -> List[List[Prediction]]:
        queries = [Query(x) for x in queries]

        # Wait for at least 1 free worker, sleeping between polls to avoid
        # a tight busy-wait
        worker_ids = []
        while len(worker_ids) == 0:
            worker_ids = self._redis_cache.get_workers()
            if len(worker_ids) == 0:
                time.sleep(PREDICT_LOOP_SLEEP_SECS)

        # For each worker, send queries to worker
        pending_queries = set()  # {(query_id, worker_id)}
        for worker_id in worker_ids:
            self._kafka_cache.add_queries_for_worker(worker_id, queries)
            # self._redis_cache.add_queries_for_worker(worker_id, queries)
            pending_queries.update([(x.id, worker_id) for x in queries])

        # Wait for all predictions to be made
        query_id_to_predictions = defaultdict(
            list)  # { <query_id>: [prediction] }
        while len(pending_queries) > 0:
            # For every pending query to worker
            for (query_id, worker_id) in list(pending_queries):
                # Check cache
                prediction = self._kafka_cache.take_prediction_for_worker(
                    worker_id, query_id)
                # prediction = self._redis_cache.take_prediction_for_worker(worker_id, query_id)
                if prediction is None:
                    continue

                # Record prediction & mark as not pending
                query_id_to_predictions[query_id].append(prediction)
                pending_queries.remove((query_id, worker_id))

            time.sleep(PREDICT_LOOP_SLEEP_SECS)

        # Reorganize predictions into one list of worker predictions per query
        worker_predictions_list = []
        for query in queries:
            worker_predictions = query_id_to_predictions[query.id]
            worker_predictions_list.append(worker_predictions)

        return worker_predictions_list

    def _combine_worker_predictions(
            self,
            worker_predictions_list: List[List[Prediction]]) -> List[Any]:
        # Ensemble predictions for each query
        predictions = []
        for worker_predictions in worker_predictions_list:
            # Transform predictions & remove all null predictions
            worker_predictions = [
                x.prediction for x in worker_predictions
                if x.prediction is not None
            ]

            # Do ensembling
            prediction = self._ensemble_method(worker_predictions)
            predictions.append(prediction)

        return predictions

    def _notify_stop(self):
        superadmin_client().send_event('predictor_stopped',
                                       inference_job_id=self._inference_job_id)
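# --- Ensemble method sketch ---
# get_ensemble_method() is project-provided and task-specific; this sketch
# only illustrates the Callable[[List[Any]], Any] contract that
# _combine_worker_predictions() relies on: given the surviving per-worker
# predictions for a single query, return one combined prediction. A majority
# vote over hashable class labels is shown as a hypothetical example.
def _example_majority_vote(worker_predictions: List[Any]) -> Any:
    from collections import Counter

    if len(worker_predictions) == 0:
        # Every worker errored on this query; surface it as a null prediction
        return None
    (label, _count) = Counter(worker_predictions).most_common(1)[0]
    return label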