Ejemplo n.º 1
0
    def predict(self, X=None, return_logit=False, n_predictions=None):
        """ Compute prediction of an answer to a question

        Parameters
        ----------
        X: str or list of strings
            Sample (question) or list of samples to perform a prediction on

        return_logit: boolean
            Whether to return logit of best answer or not. Default: False.
            Forwarded unchanged to the reader.

        n_predictions: int or None
            Forwarded unchanged to the reader. Default: None.
            NOTE(review): presumably the number of answers the reader should
            return -- confirm against the reader implementation.

        Returns
        -------
        If X is str
        prediction: tuple (answer, title, paragraph)

        If X is list of strings
        predictions: list of tuples (answer, title, paragraph)

        If return_logit is True, each prediction tuple will have the following
        structure: (answer, title, paragraph, best logit)

        Raises
        ------
        TypeError
            If X is neither a string nor a list.

        """

        def answer(question):
            # One question through the whole chain:
            # retriever -> SQuAD examples -> processor -> reader.
            closest_docs_indices = self.retriever.predict(
                question, metadata=self.metadata)
            squad_examples = generate_squad_examples(
                question=question,
                closest_docs_indices=closest_docs_indices,
                metadata=self.metadata,
                # Applied to single AND batched questions so both input
                # forms honor the same pipeline setting.
                retrieve_by_doc=self.retrieve_by_doc,
            )
            examples, features = self.processor_predict.fit_transform(
                X=squad_examples)
            return self.reader.predict((examples, features), return_logit,
                                       n_predictions)

        if isinstance(X, str):
            return answer(X)

        if isinstance(X, list):
            return [answer(query) for query in X]

        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the message.
        raise TypeError(
            "The input is not a string or a list. "
            "Please provide a string or a list of strings as input")
Ejemplo n.º 2
0
    def predict(self, X=None):
        """ Compute prediction of an answer to a question

        Parameters
        ----------
        X: str or list of strings
            Sample (question) or list of samples to perform a prediction on

        Returns
        -------
        If X is str
        prediction: tuple (answer, title, paragraph)

        If X is list of strings
        predictions: list of tuples (answer, title, paragraph)

        Raises
        ------
        TypeError
            If X is neither a string nor a list.

        """

        def answer(question):
            # One question through the whole chain:
            # retriever -> SQuAD examples -> processor -> reader.
            closest_docs_indices = self.retriever.predict(
                question, metadata=self.metadata)
            squad_examples = generate_squad_examples(
                question=question,
                closest_docs_indices=closest_docs_indices,
                metadata=self.metadata)
            examples, features = self.processor_predict.fit_transform(
                X=squad_examples)
            return self.reader.predict((examples, features))

        if isinstance(X, str):
            return answer(X)

        if isinstance(X, list):
            return [answer(query) for query in X]

        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the message.
        raise TypeError(
            "The input is not a string or a list. "
            "Please provide a string or a list of strings as input")
Ejemplo n.º 3
0
    def predict(
        self,
        query: str = None,
        n_predictions: int = None,
        retriever_score_weight: float = 0.35,
        return_all_preds: bool = False,
    ):
        """ Compute prediction of an answer to a question

        Parameters
        ----------
        query: str
            Sample (question) to perform a prediction on

        n_predictions: int or None (default: None).
            Number of returned predictions. If None, only one prediction is
            returned. Must be >= 1 when given.

        retriever_score_weight: float (default: 0.35).
            The weight of retriever score in the final score used for prediction.
            Given retriever score and reader average of start and end logits, the final score used for ranking is:

            final_score = retriever_score_weight * retriever_score + (1 - retriever_score_weight) * (reader_avg_logit)

        return_all_preds: boolean (default: False)
            whether to return a list of all predictions done by the Reader or not

        Returns
        -------
        if return_all_preds is False:
        prediction: tuple (answer, title, paragraph, score/logit)

        if return_all_preds is True:
        List of dictionaries with all metadata of all answers outputted by the Reader
        given the question.

        Raises
        ------
        TypeError
            If query is not a string, or if n_predictions is neither None
            nor an integer >= 1.

        """

        if not isinstance(query, str):
            raise TypeError(
                "The input is not a string. Please provide a string as input.")
        # Reject everything that is not None or an int >= 1. The type check
        # runs first so a non-numeric value is never compared against 1.
        if n_predictions is not None and (
                not isinstance(n_predictions, int) or n_predictions < 1):
            raise TypeError(
                "n_predictions should be a positive Integer or None")
        best_idx_scores = self.retriever.predict(query)
        squad_examples = generate_squad_examples(
            question=query,
            best_idx_scores=best_idx_scores,
            metadata=self.metadata,
            retrieve_by_doc=self.retrieve_by_doc,
        )
        examples, features = self.processor_predict.fit_transform(
            X=squad_examples)
        prediction = self.reader.predict(
            X=(examples, features),
            n_predictions=n_predictions,
            retriever_score_weight=retriever_score_weight,
            return_all_preds=return_all_preds,
        )
        return prediction
    def get_bert_predictions(self, query, retriever_result):
        """ Run the wrapped pipeline's reader on an existing retriever result

        Parameters
        ----------
        query:
            Question forwarded to ``generate_squad_examples``.

        retriever_result:
            Passed to ``self.get_indexed_top_scores`` to obtain the
            ``best_idx_scores`` argument of ``generate_squad_examples``.

        Returns
        -------
        Whatever ``self.cdqa_pipeline.reader.predict`` returns when called
        with ``n_predictions=None`` and ``return_all_preds=True``.

        """
        pipeline = self.cdqa_pipeline

        # Paragraph-level metadata is rebuilt from self.prediction_data on
        # every call.
        expand_paragraphs = pipeline._expand_paragraphs
        paragraphs = expand_paragraphs(to_df(self.prediction_data))
        top_scores = self.get_indexed_top_scores(retriever_result)

        squad_input = generate_squad_examples(
            question=query,
            best_idx_scores=top_scores,
            metadata=paragraphs,
            retrieve_by_doc=False,
        )
        examples, features = pipeline.processor_predict.fit_transform(
            X=squad_input)

        reader_kwargs = dict(
            X=(examples, features),
            n_predictions=None,
            retriever_score_weight=self.retriever_score_weight,
            return_all_preds=True,
        )
        return pipeline.reader.predict(**reader_kwargs)