def test_merge_token_labels_to_slot(self):
    data = get_test_sample()
    for i in data:
        self.assertEqual(
            merge_token_labels_to_slot(i["token_ranges"], i["labels"]),
            i["output"],
        )
        self.assertEqual(
            merge_token_labels_to_slot(
                i["token_ranges"],
                [strip_bio_prefix(l) for l in i["labels"]],
                use_bio_label=False,
            ),
            i["output"],
        )
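# For illustration only: a hypothetical sample in the shape get_test_sample()
# returns (assumed structure and values, not actual test data).
# merge_token_labels_to_slot merges per-token BIO labels into the
# "start:end:label" slot-string format that parse_slot_string consumes.
example = {
    "token_ranges": [(0, 4), (5, 7), (8, 14)],    # char (start, end) per token
    "labels": ["B-person", "NoLabel", "B-city"],  # one BIO label per token
    "output": "0:4:person,8:14:city",             # expected merged slot string
}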
def _unnumberize(self, preds, tokens, doc_str):
    """
    Re-tokenize and re-numberize the raw context (doc_str) to get doc_tokens,
    which gives us access to the start_idx and end_idx mappings. At this
    point, ans_token_start is the start index of the answer within tokens and
    ans_token_end is its end index. We compute the offset of doc_tokens within
    tokens, then use the answer token indices to recover start_idx and end_idx
    as well as the corresponding span in the raw text.
    """
    # start_idxs and end_idxs are lists of char start and end positions in doc_str.
    doc_tokens, start_idxs, end_idxs = self.tensorizer._lookup_tokens(doc_str)
    # Find the offset of doc_tokens within tokens.
    try:
        offset_end = tokens.index(self.tensorizer.vocab.get_pad_index()) - 1
    except ValueError:
        # No padding present; the sequence runs to the end.
        offset_end = len(tokens) - 1
    offset_start = list(
        map(
            lambda x: tokens[x:offset_end] == doc_tokens[: offset_end - x],
            range(offset_end),
        )
    ).index(True)
    # Find each answer's char indices and the corresponding strings.
    pred_labels = self._process_pred(preds[offset_start:offset_end])
    token_range = list(zip(start_idxs, end_idxs))
    pred_slots = parse_slot_string(
        merge_token_labels_to_slot(
            token_range,
            pred_labels,
            self.tensorizer.use_bio_labels,
        )
    )
    ans_strs = []
    ans_start_char_idxs = []
    ans_end_char_idxs = []
    for slot in pred_slots:
        # Skip slots that are not answer spans.
        if slot.label in map(
            str,
            [self.tensorizer.labels_vocab.pad_token, Slot.NO_LABEL_SLOT],
        ):
            continue
        ans_strs.append(doc_str[slot.start : slot.end])
        ans_start_char_idxs.append(slot.start)
        ans_end_char_idxs.append(slot.end)
    return ans_strs, ans_start_char_idxs, ans_end_char_idxs
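# A minimal sketch of the offset search above, using toy token ids (assumed
# values; real ids come from the tensorizer's vocab). tokens holds the full
# numberized input (e.g. question + context + padding) and doc_tokens is the
# re-numberized raw context we want to locate inside it. The next() scan is
# equivalent to the list(map(...)).index(True) construct in _unnumberize.
tokens = [101, 2054, 2003, 102, 7592, 2088, 102, 0, 0]
doc_tokens = [7592, 2088]
pad_index = 0

try:
    offset_end = tokens.index(pad_index) - 1  # last non-pad position
except ValueError:
    offset_end = len(tokens) - 1              # no padding present

# Slide the window start until tokens aligns with the start of doc_tokens.
offset_start = next(
    x for x in range(offset_end)
    if tokens[x:offset_end] == doc_tokens[: offset_end - x]
)
assert offset_start == 4  # doc_tokens begins at index 4 of tokens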
def gen_extra_context(self):
    self.all_context["slots_prediction"] = [
        merge_token_labels_to_slot(
            token_range[0:seq_len],
            self.process_pred(word_pred[0:seq_len]),
            self.use_bio_labels,
        )
        for word_pred, seq_len, token_range in zip(
            self.all_word_preds,
            self.all_context[DatasetFieldName.SEQ_LENS],
            self.all_context[DatasetFieldName.TOKEN_RANGE],
        )
    ]
    self.all_context[DOC_LABEL_NAMES] = self.doc_label_names
def calculate_metric(self):
    return compute_prf1_metrics(
        [
            NodesPredictionPair(
                get_slots(
                    merge_token_labels_to_slot(
                        token_range,
                        self.process_pred(pred[0:seq_len]),
                        self.use_bio_labels,
                    )
                ),
                get_slots(slots_label),
            )
            for pred, seq_len, token_range, slots_label in zip(
                self.all_preds,
                self.all_context[DatasetFieldName.SEQ_LENS],
                self.all_context[DatasetFieldName.TOKEN_RANGE],
                self.all_context[DatasetFieldName.RAW_WORD_LABEL],
            )
        ]
    )[1]
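# Note: compute_prf1_metrics returns a pair, and the [1] index keeps only the
# aggregated metrics object, discarding the per-label breakdown. This reading
# is inferred from the call site above; the exact return signature is assumed.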
def aggregate_preds(self, batch_preds, batch_context):
    intent_preds, word_preds = batch_preds
    self.all_preds.extend(
        [
            create_frame(
                text,
                self.doc_label_names[intent_pred],
                merge_token_labels_to_slot(
                    token_range[0:seq_len],
                    [self.word_label_names[p] for p in word_pred[0:seq_len]],
                    self.use_bio_labels,
                ),
                byte_length(text),
            )
            for text, intent_pred, word_pred, seq_len, token_range in zip(
                batch_context[self.text_column_name],
                intent_preds,
                word_preds,
                batch_context[DatasetFieldName.SEQ_LENS],
                batch_context[DatasetFieldName.TOKEN_RANGE],
            )
        ]
    )
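# Illustrative input shapes for aggregate_preds (toy values, assumed):
#   batch_preds = (intent_preds, word_preds)
#     intent_preds: one intent id per example, e.g. [3, 0]
#     word_preds:   one word-label id per token, e.g. [[1, 2, 0], [0, 0, 0]]
#   batch_context[DatasetFieldName.SEQ_LENS]:    true token counts, e.g. [3, 2]
#   batch_context[DatasetFieldName.TOKEN_RANGE]: char (start, end) per token
# Each example becomes a frame whose label is the predicted intent and whose
# slots come from merge_token_labels_to_slot over the unpadded word labels.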