Example #1
    def __init__(self, service_id, meta_store=None):
        self._service_id = service_id
        self._meta_store = meta_store or MetaStore()
        self._redis_host = os.getenv('REDIS_HOST', 'singa_auto_redis')
        self._redis_port = os.getenv('REDIS_PORT', 6379)
        self._kafka_host = os.getenv('KAFKA_HOST', 'singa_auto_kafka')
        self._kafka_port = os.getenv('KAFKA_PORT', 9092)
        self._ensemble_method: Optional[Callable[[List[Any]], Any]] = None

        self._pull_job_info()
        self._redis_cache = RedisInferenceCache(self._inference_job_id,
                                                self._redis_host,
                                                self._redis_port)
        self._kafka_cache = KafkaInferenceCache()
        logger.info(
            f'Initialized predictor for inference job "{self._inference_job_id}"'
        )
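
The constructor above reads its Redis and Kafka connection settings from environment variables. Note that os.getenv returns a string whenever the variable is set, while the defaults here are ints, so the port values can arrive as either type. Below is a minimal, self-contained sketch of the same pattern with the ports normalized to int; the helper name read_cache_config is illustrative and not part of singa-auto.

# Sketch only: same env-var pattern as above, with ports coerced to int.
import os

def read_cache_config():
    return {
        'redis_host': os.getenv('REDIS_HOST', 'singa_auto_redis'),
        'redis_port': int(os.getenv('REDIS_PORT', 6379)),
        'kafka_host': os.getenv('KAFKA_HOST', 'singa_auto_kafka'),
        'kafka_port': int(os.getenv('KAFKA_PORT', 9092)),
    }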
Example #2
    def __init__(self,
                 service_id,
                 worker_id,
                 meta_store=None,
                 param_store=None):
        self._service_id = service_id
        self._worker_id = worker_id
        self._meta_store = meta_store or MetaStore()
        self._param_store = param_store or FileParamStore()
        self._redis_host = os.getenv('REDIS_HOST', 'singa_auto_redis')
        self._redis_port = os.getenv('REDIS_PORT', 6379)
        self._kafka_host = os.getenv('KAFKA_HOST', 'singa_auto_kafka')
        self._kafka_port = os.getenv('KAFKA_PORT', 9092)
        self._batch_size = PREDICT_BATCH_SIZE
        self._redis_cache: Optional[RedisInferenceCache] = None
        self._inference_job_id = None
        self._model_inst: Optional[BaseModel] = None
        self._proposal: Optional[Proposal] = None
        self._store_params_id = None
        self._py_model_class: Optional[Type[BaseModel]] = None
        self._kafka_cache = KafkaInferenceCache()
Example #3
class InferenceWorker():
    def __init__(self,
                 service_id,
                 worker_id,
                 meta_store=None,
                 param_store=None):
        self._service_id = service_id
        self._worker_id = worker_id
        self._meta_store = meta_store or MetaStore()
        self._param_store = param_store or FileParamStore()
        self._redis_host = os.getenv('REDIS_HOST', 'singa_auto_redis')
        self._redis_port = os.getenv('REDIS_PORT', 6379)
        self._kafka_host = os.getenv('KAFKA_HOST', 'singa_auto_kafka')
        self._kafka_port = os.getenv('KAFKA_PORT', 9092)
        self._batch_size = PREDICT_BATCH_SIZE
        self._redis_cache: Optional[RedisInferenceCache] = None
        self._inference_job_id = None
        self._model_inst: Optional[BaseModel] = None
        self._proposal: Optional[Proposal] = None
        self._store_params_id = None
        self._py_model_class: Optional[Type[BaseModel]] = None
        self._kafka_cache = KafkaInferenceCache()

    def start(self):
        self._pull_job_info()
        self._redis_cache = RedisInferenceCache(self._inference_job_id,
                                                self._redis_host,
                                                self._redis_port)

        logger.info(
            f'Starting worker for inference job "{self._inference_job_id}"...')

        self._notify_start()

        # Load trial's model instance
        self._model_inst = self._load_trial_model()

        while True:
            queries = self._fetch_queries()
            if len(queries) > 0:
                predictions = self._predict(queries)
                self._submit_predictions(predictions)
            else:
                time.sleep(LOOP_SLEEP_SECS)

    def stop(self):
        self._notify_stop()

        # Run model destroy
        try:
            if self._model_inst is not None:
                self._model_inst.destroy()
        except Exception:
            logger.error('Error destroying model:')
            logger.error(traceback.format_exc())

        # Run model class teardown
        try:
            if self._py_model_class is not None:
                self._py_model_class.teardown()
        except Exception:
            logger.error('Error tearing down model class:')
            logger.error(traceback.format_exc())

    def _pull_job_info(self):
        service_id = self._service_id

        logger.info('Reading job info from meta store...')
        with self._meta_store:
            worker = self._meta_store.get_inference_job_worker(service_id)
            if worker is None:
                raise InvalidWorkerError(
                    'No such worker "{}"'.format(service_id))

            inference_job = self._meta_store.get_inference_job(
                worker.inference_job_id)
            if inference_job is None:
                raise InvalidWorkerError(
                    'No such inference job with ID "{}"'.format(
                        worker.inference_job_id))
            # Job deployed directly from a model checkpoint (no trained trial)
            if inference_job.model_id:
                model = self._meta_store.get_model(inference_job.model_id)
                logger.info(f'Using checkpoint of the model "{model.name}"...')

                self._proposal = Proposal.from_jsonable({
                    "trial_no": 1,
                    "knobs": {}
                })
                self._store_params_id = model.checkpoint_id
            else:
                # Job deployed from a trained trial's saved parameters
                trial = self._meta_store.get_trial(worker.trial_id)
                if trial is None or trial.store_params_id is None:  # Must have model saved
                    raise InvalidTrialError(
                        'No saved trial with ID "{}"'.format(worker.trial_id))
                logger.info(f'Using trial "{trial.id}"...')

                model = self._meta_store.get_model(trial.model_id)
                if model is None:
                    raise InvalidTrialError(
                        'No such model with ID "{}"'.format(trial.model_id))
                logger.info(f'Using model "{model.name}"...')

                self._proposal = Proposal.from_jsonable(trial.proposal)
                self._store_params_id = trial.store_params_id

            self._inference_job_id = inference_job.id
            self._py_model_class = load_model_class(model.model_file_bytes,
                                                    model.model_class)

    def _load_trial_model(self):
        logger.info('Loading saved model parameters from store...')
        params = self._param_store.load(self._store_params_id)

        logger.info('Loading trial\'s trained model...')
        model_inst = self._py_model_class(**self._proposal.knobs)
        model_inst.load_parameters(params)

        return model_inst

    def _notify_start(self):
        superadmin_client().send_event('inference_job_worker_started',
                                       inference_job_id=self._inference_job_id)
        self._redis_cache.add_worker(self._worker_id)

    def _fetch_queries(self) -> List[Query]:
        queries = self._kafka_cache.pop_queries_for_worker(
            self._worker_id, self._batch_size)
        return queries

    def _predict(self, queries: List[Query]) -> List[Prediction]:
        # Pass queries to model, set null predictions if it errors
        try:
            predictions = self._model_inst.predict([x.query for x in queries])
        except Exception:
            logger.error('Error while making predictions:')
            logger.error(traceback.format_exc())
            predictions = [None for x in range(len(queries))]

        # Transform predictions, adding associated worker & query ID
        predictions = [
            Prediction(x, query.id, self._worker_id)
            for (x, query) in zip(predictions, queries)
        ]

        return predictions

    def _submit_predictions(self, predictions: List[Prediction]):
        self._kafka_cache.add_predictions_for_worker(self._worker_id,
                                                     predictions)

    def _notify_stop(self):
        self._redis_cache.delete_worker(self._worker_id)
        superadmin_client().send_event('inference_job_worker_stopped',
                                       inference_job_id=self._inference_job_id)
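
A hedged usage sketch for the InferenceWorker above: start() blocks in the fetch–predict–submit loop, so a small launcher can run it in the foreground and call stop() when the process is told to shut down. The environment variable names WORKER_SERVICE_ID and WORKER_ID are illustrative, not singa-auto's own, and InferenceWorker is assumed to be the class from Example #3.

# Sketch only: run the worker until SIGTERM, then tear it down.
import os
import signal

def main():
    worker = InferenceWorker(service_id=os.environ['WORKER_SERVICE_ID'],
                             worker_id=os.environ['WORKER_ID'])

    def handle_sigterm(signum, frame):
        raise KeyboardInterrupt

    signal.signal(signal.SIGTERM, handle_sigterm)

    try:
        worker.start()  # Blocks: fetch queries -> predict -> submit predictions
    except KeyboardInterrupt:
        pass
    finally:
        worker.stop()

if __name__ == '__main__':
    main()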
Example #4
class Predictor:
    def __init__(self, service_id, meta_store=None):
        self._service_id = service_id
        self._meta_store = meta_store or MetaStore()
        self._redis_host = os.getenv('REDIS_HOST', 'singa_auto_redis')
        self._redis_port = os.getenv('REDIS_PORT', 6379)
        self._kafka_host = os.getenv('KAFKA_HOST', 'singa_auto_kafka')
        self._kafka_port = os.getenv('KAFKA_PORT', 9092)
        self._ensemble_method: Optional[Callable[[List[Any]], Any]] = None

        self._pull_job_info()
        self._redis_cache = RedisInferenceCache(self._inference_job_id,
                                                self._redis_host,
                                                self._redis_port)
        self._kafka_cache = KafkaInferenceCache()
        logger.info(
            f'Initialized predictor for inference job "{self._inference_job_id}"'
        )

    # Only a single thread should run this
    def start(self):
        self._notify_start()

    def predict(self, queries):
        worker_predictions_list = self._get_predictions_from_workers(queries)
        predictions = self._combine_worker_predictions(worker_predictions_list)
        return predictions

    # Only a single thread should run this
    def stop(self):
        self._notify_stop()

        # Clear caches for inference job
        try:
            self._redis_cache.clear_all()
        except Exception:
            logger.error('Error clearing inference cache:')
            logger.error(traceback.format_exc())

    def _pull_job_info(self):
        service_id = self._service_id

        logger.info('Reading job info from meta store...')
        with self._meta_store:
            inference_job = self._meta_store.get_inference_job_by_predictor(
                service_id)
            if inference_job is None:
                raise InvalidInferenceJobError(
                    'No inference job associated with predictor "{}"'.format(
                        service_id))

            if inference_job.train_job_id is None and inference_job.model_id is None:
                raise InvalidInferenceJobError(
                    'No train job or checkpoint found with inference ID "{}"'.
                    format(inference_job.id))

            if inference_job.train_job_id is not None:
                train_job = self._meta_store.get_train_job(
                    inference_job.train_job_id)
                if train_job is None:
                    raise InvalidInferenceJobError(
                        'No such train job with ID "{}"'.format(
                            inference_job.train_job_id))
                self._ensemble_method = get_ensemble_method(train_job.task)
            if inference_job.model_id is not None:
                self._ensemble_method = get_ensemble_method()

            self._inference_job_id = inference_job.id

            logger.info(f'Using ensemble method: {self._ensemble_method}...')

    def _notify_start(self):
        superadmin_client().send_event('predictor_started',
                                       inference_job_id=self._inference_job_id)

    def _get_predictions_from_workers(
            self, queries: List[Any]) -> List[List[Prediction]]:
        queries = [Query(x) for x in queries]

        # Wait for at least 1 free worker
        worker_ids = []
        while len(worker_ids) == 0:
            worker_ids = self._redis_cache.get_workers()

        # For each worker, send queries to worker
        pending_queries = set()  # {(query_id, worker_id)}
        for worker_id in worker_ids:
            self._kafka_cache.add_queries_for_worker(worker_id, queries)
            # self._redis_cache.add_queries_for_worker(worker_id, queries)
            pending_queries.update([(x.id, worker_id) for x in queries])

        # Wait for all predictions to be made
        query_id_to_predictions = defaultdict(
            list)  # { <query_id>: [prediction] }
        while len(pending_queries) > 0:
            # For every pending query to worker
            for (query_id, worker_id) in list(pending_queries):
                # Check cache
                prediction = self._kafka_cache.take_prediction_for_worker(
                    worker_id, query_id)
                # prediction = self._redis_cache.take_prediction_for_worker(worker_id, query_id)
                if prediction is None:
                    continue

                # Record prediction & mark as not pending
                query_id_to_predictions[query_id].append(prediction)
                pending_queries.remove((query_id, worker_id))

            time.sleep(PREDICT_LOOP_SLEEP_SECS)

        # Reorganize predictions
        worker_predictions_list = []
        for query in queries:
            worker_predictions = query_id_to_predictions[query.id]
            worker_predictions_list.append(worker_predictions)

        return worker_predictions_list

    def _combine_worker_predictions(
            self,
            worker_predictions_list: List[List[Prediction]]) -> List[Any]:
        # Ensemble predictions for each query
        predictions = []
        for worker_predictions in worker_predictions_list:
            # Transform predictions & remove all null predictions
            worker_predictions = [
                x.prediction for x in worker_predictions
                if x.prediction is not None
            ]

            # Do ensembling
            prediction = self._ensemble_method(worker_predictions)
            predictions.append(prediction)

        return predictions

    def _notify_stop(self):
        superadmin_client().send_event('predictor_stopped',
                                       inference_job_id=self._inference_job_id)
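
In Example #4, _combine_worker_predictions delegates the actual combining to self._ensemble_method, which _pull_job_info obtains from get_ensemble_method (not shown in these examples). As a rough illustration of a function with the expected Callable[[List[Any]], Any] shape, here is a sketch that averages numeric predictions (e.g. probability vectors) and falls back to the first worker's answer otherwise; ensemble_average is a hypothetical name, not singa-auto's implementation.

# Sketch only: one possible ensembling function for the Predictor above.
from typing import Any, List

import numpy as np

def ensemble_average(worker_predictions: List[Any]) -> Any:
    # The caller has already dropped null predictions
    if len(worker_predictions) == 0:
        return None
    try:
        # Numeric predictions (scalars or equal-length vectors): average them
        return np.mean(np.asarray(worker_predictions, dtype=float), axis=0).tolist()
    except (TypeError, ValueError):
        # Anything else (e.g. class labels or strings): take the first worker's answer
        return worker_predictions[0]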