Code example #1
import json

from experiments.squad import squad_utils

# MRC_MAX_SEQ_LEN, MAX_QUERY_LEN and DOC_STRIDE are module-level constants
# defined elsewhere in the surrounding preprocessing script.
def build_data_mrc(data,
                   dump_path,
                   max_seq_len=MRC_MAX_SEQ_LEN,
                   tokenizer=None,
                   label_mapper=None,
                   is_training=True):
     with open(dump_path, 'w', encoding='utf-8') as writer:
         unique_id = 1000000000  # TODO: this is from BERT, needed to remove it...
         for example_index, sample in enumerate(data):
             ids = sample['uid']
             doc = sample['premise']
             query = sample['hypothesis']
             label = sample['label']
             doc_tokens, cw_map = squad_utils.token_doc(doc)
             answer_start, answer_end, answer, is_impossible = squad_utils.parse_squad_label(
                 label)
             answer_start_adjusted, answer_end_adjusted = squad_utils.recompute_span(
                 answer, answer_start, cw_map)
             is_valid = squad_utils.is_valid_answer(doc_tokens,
                                                    answer_start_adjusted,
                                                    answer_end_adjusted,
                                                    answer)
             if not is_valid: continue
             """
             TODO --xiaodl: support RoBERTa
             """
             feature_list = squad_utils.mrc_feature(tokenizer,
                                                    unique_id,
                                                    example_index,
                                                    query,
                                                    doc_tokens,
                                                    answer_start_adjusted,
                                                    answer_end_adjusted,
                                                    is_impossible,
                                                    max_seq_len,
                                                    MAX_QUERY_LEN,
                                                    DOC_STRIDE,
                                                    answer_text=answer,
                                                     is_training=is_training)
             unique_id += len(feature_list)
             for feature in feature_list:
                 so = json.dumps({
                     'uid': ids,
                     'token_id': feature.input_ids,
                     'mask': feature.input_mask,
                     'type_id': feature.segment_ids,
                     'example_index': feature.example_index,
                     'doc_span_index': feature.doc_span_index,
                     'tokens': feature.tokens,
                     'token_to_orig_map': feature.token_to_orig_map,
                     'token_is_max_context': feature.token_is_max_context,
                     'start_position': feature.start_position,
                     'end_position': feature.end_position,
                     'label': feature.is_impossible,
                     'doc': doc,
                     'doc_offset': feature.doc_offset,
                     'answer': [answer]
                 })
                 writer.write('{}\n'.format(so))
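
For context, here is a minimal, hedged sketch of how build_data_mrc might be driven. The sample-dict keys (uid, premise, hypothesis, label) mirror the fields read above, and the TSV column order is taken from code example #3; the tokenizer choice and the output path are illustrative assumptions, not values from the original script.

from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# Build the list of sample dicts that build_data_mrc() expects; the packed
# label string is passed through unchanged for squad_utils.parse_squad_label().
data = []
with open(r"data\canonical_data\squad_v2_train.tsv", encoding="utf-8") as f:
    for line in f:
        uid, label, context, question = line.strip().split("\t")
        data.append({'uid': uid, 'premise': context,
                     'hypothesis': question, 'label': label})

build_data_mrc(data,
               "squad_v2_train_mrc.json",  # illustrative output path
               tokenizer=tokenizer)
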
Code example #2
# json, MAX_SEQ_LEN, EncoderModelType, TaskType and the bert/roberta/xlnet
# feature extractors are defined or imported at module level in the surrounding
# preprocessing script; `task_type` is referenced below but not defined in this
# snippet and is assumed to be available from the enclosing scope.
def build_data_premise_and_one_hypo(
        data,
        dump_path,
        max_seq_len=MAX_SEQ_LEN,
        tokenizer=None,
        encoderModelType=EncoderModelType.BERT):
     """Build data of sentence pair tasks
     """
     with open(dump_path, 'w', encoding='utf-8') as writer:
         for idx, sample in enumerate(data):
             ids = sample['uid']
             premise = sample['premise']
             hypothesis = sample['hypothesis']
             label = sample['label']
             if encoderModelType == EncoderModelType.ROBERTA:
                 input_ids, input_mask, type_ids = roberta_feature_extractor(
                     premise,
                     hypothesis,
                     max_seq_length=max_seq_len,
                     model=tokenizer)
                 features = {
                     'uid': ids,
                     'label': label,
                     'token_id': input_ids,
                     'type_id': type_ids,
                     'mask': input_mask
                 }
             elif encoderModelType == EncoderModelType.XLNET:
                 input_ids, input_mask, type_ids = xlnet_feature_extractor(
                     premise,
                     hypothesis,
                     max_seq_length=max_seq_len,
                     tokenize_fn=tokenizer)
                 features = {
                     'uid': ids,
                     'label': label,
                     'token_id': input_ids,
                     'type_id': type_ids,
                     'mask': input_mask
                 }
             else:
                 input_ids, _, type_ids = bert_feature_extractor(
                     premise,
                     hypothesis,
                     max_seq_length=max_seq_len,
                     tokenize_fn=tokenizer)
                 if task_type == TaskType.Span:
                     seg_a_start = len(type_ids) - sum(type_ids)
                     seg_a_end = len(type_ids)
                     answer_start, answer_end, answer, is_impossible = squad_utils.parse_squad_label(
                         label)
                     span_start, span_end = squad_utils.calc_tokenized_span_range(
                         premise, hypothesis, answer, answer_start,
                         answer_end, tokenizer, encoderModelType)
                     span_start = seg_a_start + span_start
                     span_end = min(seg_a_end, seg_a_start + span_end)
                     answer_tokens = tokenizer.convert_ids_to_tokens(
                         input_ids[span_start:span_end])
                     if span_start >= span_end:
                         span_start = -1
                         span_end = -1
                     features = {
                         'uid': ids,
                         'label': is_impossible,
                         'answer': answer,
                         "answer_tokens": answer_tokens,
                         "token_start": span_start,
                         "token_end": span_end,
                         'token_id': input_ids,
                         'type_id': type_ids
                     }
                 else:
                     features = {
                         'uid': ids,
                         'label': label,
                         'token_id': input_ids,
                         'type_id': type_ids
                     }
             writer.write('{}\n'.format(json.dumps(features)))
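
As a usage note, here is a small, hedged sketch of calling this helper for a plain classification pair task along the default BERT path. The dict keys match those read above; the example sentences, label value, and output path are illustrative, and task_type is assumed to be set to a non-Span value in the surrounding scope so the final else branch is taken.

from pytorch_pretrained_bert import BertTokenizer
from data_utils.task_def import EncoderModelType

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

nli_data = [{
    'uid': '0',
    'premise': 'A soccer game with multiple males playing.',
    'hypothesis': 'Some men are playing a sport.',
    'label': 0,  # illustrative class id
}]

build_data_premise_and_one_hypo(nli_data,
                                'mnli_sample.json',  # illustrative dump path
                                tokenizer=tokenizer,
                                encoderModelType=EncoderModelType.BERT)
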
Code example #3
from pytorch_pretrained_bert import BertTokenizer
from data_utils.task_def import EncoderModelType
from experiments.squad.squad_utils import calc_tokenized_span_range, parse_squad_label

model = "bert-base-uncased"
do_lower_case = True
tokenizer = BertTokenizer.from_pretrained(model, do_lower_case=do_lower_case)

# Sanity-check pass: walk the canonical SQuAD v2 training TSV, re-derive the
# tokenized answer span for every row, and let verbose=True print diagnostics.
with open(r"data\canonical_data\squad_v2_train.tsv", encoding="utf-8") as reader:
    for no, line in enumerate(reader):
        if no % 1000 == 0:
            print(no)
        uid, label, context, question = line.strip().split("\t")
        answer_start, answer_end, answer, is_impossible = parse_squad_label(label)
        calc_tokenized_span_range(context, question, answer, answer_start, answer_end,
                                  tokenizer, EncoderModelType.BERT, verbose=True)
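
Since calc_tokenized_span_range returns a (start, end) token pair (see how code example #2 unpacks it), the same loop can also be turned into a summary check instead of relying on verbose printing. The variant below is a sketch under that assumption; the span_start >= span_end test for an unrecoverable span follows code example #2, and the count also includes SQuAD v2's genuinely unanswerable questions.

bad, total = 0, 0
with open(r"data\canonical_data\squad_v2_train.tsv", encoding="utf-8") as reader:
    for line in reader:
        uid, label, context, question = line.strip().split("\t")
        answer_start, answer_end, answer, is_impossible = parse_squad_label(label)
        span_start, span_end = calc_tokenized_span_range(
            context, question, answer, answer_start, answer_end,
            tokenizer, EncoderModelType.BERT)
        total += 1
        if span_start >= span_end:
            bad += 1  # no usable token span recovered (includes unanswerable rows)
print("rows without a recoverable span: {}/{}".format(bad, total))
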