def test_merge_token_labels_to_slot(self):
    """Check slot merging with labels as given and with BIO prefixes stripped.

    Each sample returned by ``get_test_sample()`` is merged twice: once
    with its ``"labels"`` unchanged, and once with every label passed
    through ``test_utils.strip_bio_prefix`` and ``use_bio_label=False``.
    Both results must equal the sample's ``"output"`` entry.
    """
    for sample in get_test_sample():
        # Pass 1: labels exactly as stored in the sample.
        merged_as_given = test_utils.merge_token_labels_to_slot(
            sample["token_ranges"], sample["labels"]
        )
        self.assertEqual(merged_as_given, sample["output"])

        # Pass 2: same token ranges, labels with their prefix stripped.
        stripped_labels = [
            test_utils.strip_bio_prefix(label) for label in sample["labels"]
        ]
        merged_stripped = test_utils.merge_token_labels_to_slot(
            sample["token_ranges"],
            stripped_labels,
            use_bio_label=False,
        )
        self.assertEqual(merged_stripped, sample["output"])
def gen_extra_context(self):
    """Add slot predictions and doc label names to ``self.all_context``.

    For each entry of ``self.all_word_preds``, the prediction is cut to the
    matching sequence length, passed through ``self.process_pred`` and then
    merged with the matching token range via ``merge_token_labels_to_slot``.
    The merged results are stored under ``"slots_prediction"``;
    ``self.doc_label_names`` is stored under ``DOC_LABEL_NAMES``.
    """
    word_preds = self.all_word_preds
    seq_lens = self.all_context[DatasetFieldName.SEQ_LENS]
    token_ranges = self.all_context[DatasetFieldName.TOKEN_RANGE]

    slots_prediction = []
    # zip stops at the shortest of the three inputs.
    for word_pred, seq_len, token_range in zip(word_preds, seq_lens, token_ranges):
        # Only the first seq_len entries of the prediction are used.
        processed = self.process_pred(word_pred[0:seq_len])
        slots_prediction.append(
            merge_token_labels_to_slot(token_range, processed, self.use_bio_labels)
        )

    self.all_context["slots_prediction"] = slots_prediction
    self.all_context[DOC_LABEL_NAMES] = self.doc_label_names
def calculate_metric(self):
    """Build prediction/label slot pairs and return the computed metric.

    Each prediction in ``self.all_preds`` is cut to its sequence length,
    passed through ``self.process_pred``, merged with its token range via
    ``merge_token_labels_to_slot`` and converted with ``get_slots``. It is
    paired with ``get_slots`` of the corresponding raw word label in a
    ``NodesPredictionPair``.

    Returns:
        The element at index 1 of ``compute_prf1_metrics(pairs)``.
        NOTE(review): the meaning of that element is defined by
        ``compute_prf1_metrics``, which is not visible here -- confirm.
    """
    seq_lens = self.all_context[DatasetFieldName.SEQ_LENS]
    token_ranges = self.all_context[DatasetFieldName.TOKEN_RANGE]
    raw_labels = self.all_context[DatasetFieldName.RAW_WORD_LABEL]

    pairs = []
    # zip stops at the shortest of the four inputs.
    for pred, seq_len, token_range, slots_label in zip(
        self.all_preds, seq_lens, token_ranges, raw_labels
    ):
        # Only the first seq_len entries of the prediction are used.
        predicted_slots = merge_token_labels_to_slot(
            token_range,
            self.process_pred(pred[0:seq_len]),
            self.use_bio_labels,
        )
        pairs.append(
            NodesPredictionPair(get_slots(predicted_slots), get_slots(slots_label))
        )

    metrics = compute_prf1_metrics(pairs)
    return metrics[1]