Beispiel #1
0
def get_chosen_logits(gold_answers, answers):
    """Collect the logit of the predicted label, split by correctness.

    The answers are first partitioned with ``get_correct_answers``. For
    every answer the logit at index ``label_to_id(ans.pred_label)`` is
    picked from ``ans.logits``.

    Args:
        gold_answers: Gold answers, passed through to
            ``get_correct_answers``.
        answers: Predicted answers exposing ``logits`` and ``pred_label``.

    Returns:
        A ``(correct_logits, incorrect_logits)`` tuple of two lists.
    """

    def _chosen_logit(ans):
        # Logit the model assigned to the label it actually predicted.
        return ans.logits[label_to_id(ans.pred_label)]

    right, wrong = get_correct_answers(gold_answers, answers)
    right_logits = [_chosen_logit(ans) for ans in right]
    wrong_logits = [_chosen_logit(ans) for ans in wrong]
    return right_logits, wrong_logits
Beispiel #2
0
 def __init__(
     self,
     example_id: str,
     pred_label: str,
     label: str = None,
     probs: List[float] = None,
     endings: Optional[List[str]] = None,
     logits: Optional[List[float]] = None,
     threshold: float = 0.0,
     no_answer: float = -1.0,
     no_answer_text: str = None,
     is_no_answer: bool = False,
 ):
     """Store one prediction and decide whether it is a "no answer".

     Every argument is saved on the instance under the same name. In
     addition ``probs_field`` is set to ``'probs'``.

     Args:
         example_id: Identifier of the example this answer belongs to.
         pred_label: Predicted label; may be None.
         label: Optional label stored alongside the prediction.
         probs: Optional per-choice probabilities.
         endings: Optional list of ending strings.
         logits: Optional per-choice logits.
         threshold: Stored threshold value.
         no_answer: Value that stands for "no answer".
         no_answer_text: Optional text associated with "no answer".
         is_no_answer: Explicitly mark this prediction as "no answer".
     """
     # Identity and labels.
     self.example_id = example_id
     self.pred_label = pred_label
     self.label = label

     # Scores and candidate endings.
     self.probs = probs
     self.endings = endings
     self.logits = logits

     # "No answer" configuration.
     self.threshold = threshold
     self.no_answer = no_answer
     self.no_answer_text = no_answer_text
     self.probs_field = 'probs'

     if is_no_answer or self.pred_label is None:
         # Either the caller forced the flag, or there is no predicted
         # label to inspect: keep the flag exactly as it was passed in.
         self.is_no_answer = is_no_answer
     else:
         # Otherwise the flag is derived from the predicted label: it is
         # set when the label's id equals the "no answer" value.
         self.is_no_answer = label_to_id(self.pred_label) == self.no_answer
Beispiel #3
0
 def get_labels(self, splits: Union[List[str], str] = None) -> List[int]:
     """Map each example id of the requested splits to its label ids.

     Args:
         splits: A split name or a list of split names. A single string
             is wrapped into a one-element list; any other value
             (including None) is handed to ``self.get_splits`` unchanged.

     Returns:
         A dict keyed by ``example_id``. Each value is the list obtained
         by applying ``label_to_id`` to every item yielded by iterating
         the corresponding datapoint. If an ``example_id`` occurs more
         than once, the last occurrence wins.

     NOTE(review): the return annotation says ``List[int]`` but a dict is
     returned; the annotation is left untouched here to keep the signature
     stable -- confirm and fix together with the file's typing imports.
     """
     if isinstance(splits, str):
         splits = [splits]
     return {
         datapoint.example_id: [label_to_id(item) for item in datapoint]
         for datapoint in self.get_splits(splits)
     }
Beispiel #4
0
def get_correct_answers(gold_answers, answers):
    """Partition predicted answers into correct and incorrect ones.

    Gold and predicted answers are paired positionally with ``zip``, so
    any surplus items of the longer sequence are ignored.

    Args:
        gold_answers: Iterable of gold answers exposing ``get_answer()``.
        answers: Iterable of predictions exposing ``pred_label``.

    Returns:
        A ``(correct, incorrect)`` tuple of lists holding the prediction
        objects, each in input order.
    """
    correct, incorrect = [], []
    for gold, ans in zip(gold_answers, answers):
        # The prediction is compared through its raw ``pred_label`` rather
        # than ``ans.get_answer()``; the original note says this is meant
        # to disable the threshold/probs mechanism.
        is_match = gold.get_answer() == label_to_id(ans.pred_label)
        bucket = correct if is_match else incorrect
        bucket.append(ans)
    return correct, incorrect
Beispiel #5
0
 def get_nof_choices(self) -> int:
     """Return the number of choices inferred from the stored answers.

     If the first answer in ``self.answers`` carries ``probs``, the result
     is the length of that list. Otherwise the result is the largest
     ``label_to_id(pred_label)`` over all answers, never lower than -1.

     Raises:
         IndexError: If ``self.answers`` is empty.

     NOTE(review): the fallback yields the highest label id, not a count;
     whether that equals the number of choices depends on how
     ``label_to_id`` numbers labels -- confirm.
     """
     first_key = list(self.answers.keys())[0]
     reference = self.answers[first_key]
     if reference.probs is not None:
         return len(reference.probs)

     # No probabilities available: fall back to the highest predicted
     # label id, seeded with -1 so the result is never below it.
     predicted_ids = (
         label_to_id(ans.pred_label) for ans in self.answers.values()
     )
     return max([-1, *predicted_ids])
Beispiel #6
0
    def get_answer(self, accept_no_answer=True) -> int:
        """Return the answer value for this prediction.

        Resolution order:

        1. If ``self.is_no_answer`` is set, the result is ``self.no_answer``.
        2. Otherwise, if the attribute named by ``self.probs_field`` is not
           None, the result is ``argmax`` of those scores when their maximum
           is strictly greater than ``self.threshold``, and
           ``self.no_answer`` when it is not.
        3. Otherwise the result is ``label_to_id(self.pred_label)``.

        Args:
            accept_no_answer: When false and the resolved value equals
                ``self.no_answer``, the value returned by
                ``self.search_unanswerable_option()`` is used instead.

        Returns:
            The resolved answer value.
        """
        # Converted up front in every case, so whatever ``label_to_id``
        # does with ``pred_label`` (including raising) still happens first.
        ans = label_to_id(self.pred_label)
        if self.is_no_answer:
            ans = self.no_answer
        else:
            # Look the scores up once through the public ``getattr`` instead
            # of invoking ``__getattribute__`` directly several times.
            scores = getattr(self, self.probs_field)
            if scores is not None:
                if max(scores) > self.threshold:
                    ans = argmax(scores)
                else:
                    ans = self.no_answer

        if ans == self.no_answer and not accept_no_answer:
            ans = self.search_unanswerable_option()
        return ans
Beispiel #7
0
 def apply_no_answer(
     self,
     split: Union[List[str], str],
     answers: List[Answer],
     text: str,
 ):
     """Flag answers as "no answer" when the gold ending contains ``text``.

     For every (datapoint, answer) pair, the gold ending is the entry of
     ``datapoint.endings`` at index ``label_to_id(datapoint.label)``. If
     that ending contains ``text`` and ``answer.get_answer()`` equals the
     gold index, a message is printed and ``answer.is_no_answer`` is set
     to True.

     Args:
         split: Split name(s) forwarded to ``self.get_splits``.
         answers: Predictions aligned one-to-one with the split's data.
         text: Substring searched for in the gold ending.

     Returns:
         The same ``answers`` list; its items are modified in place.

     Raises:
         ValueError: If the dataset and ``answers`` differ in length.
         AssertionError: If a paired datapoint and answer have different
             ``example_id`` values (compared as strings).
     """
     data = self.get_splits(split)
     n_data, n_answers = len(data), len(answers)
     if n_data != n_answers:
         raise ValueError(
             'Asked to set no answer on a list with different size '
             'from dataset, maybe you asked for the wrong split?'
             f'(dataset size {n_data}, nof answers: {n_answers})')

     for datapoint, answer in zip(data, answers):
         assert (str(datapoint.example_id) == str(answer.example_id))
         gold_index = label_to_id(datapoint.label)
         gold_ending = datapoint.endings[gold_index]
         if text not in gold_ending:
             continue
         # Only predictions that picked the gold ending are flagged.
         if answer.get_answer() == gold_index:
             print(f'Aplying no answer to {answer.example_id}')
             answer.is_no_answer = True
     return answers
Beispiel #8
0
 def get_pred_tuple(self) -> List[Tuple[str, float]]:
     """Return a one-element list with the (gold, predicted) pair.

     The first item of the pair is ``label_to_id(self.label)``, the second
     is the value returned by ``self.get_answer()``.
     """
     gold_id = label_to_id(self.label)
     predicted = self.get_answer()
     return [(gold_id, predicted)]