def testConvertToSpansMismatches(self):
  """Spans still align when a token's character differs from the raw text.

  The raw text carries an em-dash (\\u2014) while the token list carries an
  en-dash (\\u2013); convert_to_spans must recover the boundaries anyway.
  """
  raw_text = u"America — Harvard, Yale"
  tokens = ["America", u"\u2013", "Harvard", ",", "Yale"]
  spans = tokenizer_util.convert_to_spans(raw_text, tokens)
  self.assertEqual([(0, 7), (8, 9), (10, 17), (17, 18), (19, 23)], spans)
def squad_generator(path,
                    tokenizer_fn=word_tokenize,
                    sort_by_length=False,
                    is_subword=False):
  """Generate SQuAD data from the raw json file.

  Args:
    path: Path to a SQuAD-format json file.
    tokenizer_fn: Callable mapping a unicode string to a ``(tokens, ids)``
      pair. When ``ids`` is falsy, no ``*_ids`` feature is emitted.
    sort_by_length: If True, yield examples ordered by context token count
      (shortest first).
    is_subword: If True, locate answer token indices with
      ``get_answer_index``; otherwise align answers to tokens via character
      spans from ``tokenizer_util.convert_to_spans``.

  Yields:
    A feature dict per question, with context/question tokens, lengths,
    encoded answer strings, and answer start/end token indices.
  """
  with tf.gfile.GFile(path, 'r') as f:
    squad = json.load(f)
  examples = []
  for article in squad['data']:
    for paragraph in article['paragraphs']:
      context = paragraph['context'].strip()
      context_enc = context.encode('utf-8')
      context_tokens, context_ids = tokenizer_fn(context)
      # Character spans depend only on the paragraph, so compute them once
      # here rather than once per question inside the qas loop (the
      # original recomputed this loop-invariant value per question).
      spans = None if is_subword else tokenizer_util.convert_to_spans(
          context, context_tokens)
      for qa in paragraph['qas']:
        question = qa['question'].strip()
        id_ = qa['id']
        answers = [answer['text'].strip() for answer in qa['answers']]
        answer_starts = [answer['answer_start'] for answer in qa['answers']]
        answer_ends = [
            start + len(answer)
            for start, answer in zip(answer_starts, answers)
        ]
        feats = {}
        feats['id'] = id_
        feats['answers'] = utf_encode_list(answers)
        feats['num_answers'] = len(answers)
        feats['context'] = context_enc
        feats['context_tokens'] = context_tokens
        if context_ids:
          feats['context_ids'] = context_ids
        feats['context_length'] = len(context_tokens)
        question_tokens, question_ids = tokenizer_fn(question)
        feats['question'] = question.encode('utf-8')
        feats['question_tokens'] = utf_encode_list(question_tokens)
        if question_ids:
          feats['question_ids'] = question_ids
        feats['question_length'] = len(feats['question_tokens'])
        # One (start, end) token-index pair per reference answer.
        starts = []
        ends = []
        if is_subword:
          for answer_start, answer in zip(answer_starts, answers):
            start, end = get_answer_index(
                context=context,
                context_tokens=feats['context_tokens'],
                answer_start=answer_start,
                answer=answer)
            starts.append(start)
            ends.append(end)
        else:
          # Note: a redundant re-initialization of starts/ends that used
          # to sit here was removed; the lists above are still empty.
          for answer_start, answer_end in zip(answer_starts, answer_ends):
            start, end = get_span(spans, answer_start, answer_end)
            starts.append(start)
            ends.append(end)
        feats['answers_start_token'] = starts
        feats['answers_end_token'] = ends
        # Encode tokens only after span alignment, which needs raw strings.
        feats['context_tokens'] = utf_encode_list(feats['context_tokens'])
        examples.append(feats)
  if sort_by_length:
    examples = sorted(examples, key=lambda x: len(x['context_tokens']))
  for example in examples:
    yield example
def testConvertToSpans(self):
  """Tokens that tile the raw text map to their character-offset spans."""
  raw_text = "Convert to spans."
  tokens = ["Convert", "to", "spans", "."]
  spans = tokenizer_util.convert_to_spans(raw_text, tokens)
  self.assertEqual([(0, 7), (8, 10), (11, 16), (16, 17)], spans)