def _predict(self, queries: List[Query]) -> List[Prediction]:
    """Run the loaded model on ``queries`` and wrap each result as a ``Prediction``.

    The model's ``predict`` is called once with the raw ``query`` payload of
    every item. If it raises, the error and traceback are logged and every
    query receives a ``None`` prediction instead of the failure propagating.

    Args:
        queries: Queries to predict on; each must expose ``query`` and ``id``.

    Returns:
        One ``Prediction`` per paired (result, query), carrying the query's ID
        and this worker's ID. Pairing uses ``zip``, so if the model returns
        fewer results than queries the extra queries get no ``Prediction``.
    """
    try:
        raw_predictions = self._model_inst.predict([x.query for x in queries])
    except Exception:
        # Only swallow ordinary errors: a bare ``except`` would also trap
        # KeyboardInterrupt/SystemExit and make the worker impossible to stop
        # while a prediction is in flight.
        logger.error('Error while making predictions:')
        logger.error(traceback.format_exc())
        raw_predictions = [None] * len(queries)

    # Attach the originating query ID and this worker's ID to each result
    return [
        Prediction(raw, query.id, self._worker_id)
        for (raw, query) in zip(raw_predictions, queries)
    ]
def make_predictions(queries: List[Any], task: str,
                     py_model_class: Type[BaseModel], proposal: Proposal,
                     fine_tune_dataset_path, params: Params) -> tuple:
    """Load a trained model locally and run ``queries`` through a simulated
    predictor/worker round trip.

    A single worker with ID ``'local'`` is registered in a fresh
    ``InferenceCache``. The queries are wrapped in ``Query`` objects, pushed to
    and popped from the cache, predicted on by the model, and the resulting
    ``Prediction`` objects are pushed back and retrieved per query ID. Each
    retrieved prediction is checked with ``_assert_jsonable`` and passed
    (as a one-element list) through the ensemble method for ``task``.

    Args:
        queries: Raw query payloads to predict on.
        task: Task name; selects the ensemble method and which parameters
            are loaded into the model.
        py_model_class: Model class, instantiated with ``proposal.knobs``.
        proposal: Proposal whose ``knobs`` are passed to the model constructor.
        fine_tune_dataset_path: Passed to ``load_parameters`` when ``task`` is
            ``'question_answering_covid19'`` and this is not ``None``.
        params: Passed to ``load_parameters`` for every other task.

    Returns:
        A 2-tuple ``(out_predictions, model_inst)``: the list of ensembled
        predictions in query order, and the model instance that produced them.

    Raises:
        AssertionError: If the worker is not registered, the number of queries
            popped differs from the number sent, or a prediction is missing.
        Exception: Presumably raised by ``_assert_jsonable`` when a prediction
            is not JSON serializable (the exception instance is supplied here;
            confirm against that helper).
    """
    inference_cache: InferenceCache = InferenceCache()
    worker_id = 'local'

    # Worker loads the best trained model's parameters
    _print_header('Loading trained model...')
    model_inst = py_model_class(**proposal.knobs)
    # For 'question_answering_covid19' nothing is loaded unless a fine-tune
    # dataset path is given; every other task loads the trained `params`.
    if task == 'question_answering_covid19' and fine_tune_dataset_path is not None:
        model_inst.load_parameters(fine_tune_dataset_path)
    elif task != 'question_answering_covid19':
        model_inst.load_parameters(params)

    # Inference worker tells predictor that it is free
    inference_cache.add_worker(worker_id)

    # Predictor receives queries
    queries = [Query(x) for x in queries]

    # Predictor checks free workers
    worker_ids = inference_cache.get_workers()
    assert worker_id in worker_ids

    # Predictor sends queries to worker
    inference_cache.add_queries_for_worker(worker_id, queries)

    # Worker receives queries
    queries_at_worker = inference_cache.pop_queries_for_worker(
        worker_id, len(queries))
    assert len(queries_at_worker) == len(queries)

    # Worker makes predictions on queries
    _print_header('Making predictions with trained model...')
    predictions = model_inst.predict([x.query for x in queries_at_worker])
    predictions = [
        Prediction(x, query.id, worker_id)
        for (x, query) in zip(predictions, queries_at_worker)
    ]

    # Worker sends predictions to predictor
    inference_cache.add_predictions_for_worker(worker_id, predictions)

    # Predictor receives predictions, one lookup per original query
    predictions_at_predictor = []
    for query in queries:
        prediction = inference_cache.take_prediction_for_worker(
            worker_id, query.id)
        assert prediction is not None
        predictions_at_predictor.append(prediction)

    ensemble_method = get_ensemble_method(task)
    print(f'Ensemble method: {ensemble_method}')

    out_predictions = []
    for prediction_at_predictor in predictions_at_predictor:
        raw_prediction = prediction_at_predictor.prediction
        _assert_jsonable(
            raw_prediction,
            Exception('Each `prediction` should be JSON serializable'))
        # Only one worker exists here, so the ensemble sees a single prediction
        out_predictions.append(ensemble_method([raw_prediction]))

    print('Predictions: {}'.format(out_predictions))

    return (out_predictions, model_inst)