Example #1
def testConvertToSpansMismatches(self):
    raw_text = u"America — Harvard, Yale"  # "—": em-dash (\u2014)
    tokens = ["America", u"\u2013", "Harvard", ",",
              "Yale"]  # \u2013: en-dash
    output = tokenizer_util.convert_to_spans(raw_text, tokens)
    expected_output = [(0, 7), (8, 9), (10, 17), (17, 18), (19, 23)]
    self.assertEqual(expected_output, output)
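This test exercises mismatch tolerance in tokenizer_util.convert_to_spans: the en-dash token is still aligned to the em-dash in the raw string. The library's actual implementation is not shown here; the sketch below is only an assumed, minimal version of that alignment contract, though it does reproduce the expected output of both tests on this page.

def convert_to_spans_sketch(raw_text, tokens):
    """Assumed behavior: skip whitespace, then assign each token the
    next len(token) characters, even when they differ character-wise."""
    spans = []
    cursor = 0
    for token in tokens:
        # Advance past inter-token whitespace in the raw string.
        while cursor < len(raw_text) and raw_text[cursor].isspace():
            cursor += 1
        start = cursor
        cursor += len(token)  # character mismatches are tolerated
        spans.append((start, cursor))
    return spans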
Example #2
import json

import tensorflow as tf


def squad_generator(path,
                    tokenizer_fn=word_tokenize,
                    sort_by_length=False,
                    is_subword=False):
    """Generate SQuAD examples from the raw JSON file.

    Depends on `tokenizer_util` and the helpers `word_tokenize`,
    `utf_encode_list`, `get_answer_index`, and `get_span`, which are
    defined elsewhere in the source repository.
    """

    with tf.gfile.GFile(path, 'r') as f:
        squad = json.load(f)

    examples = []
    for article in squad['data']:

        for paragraph in article['paragraphs']:
            context = paragraph['context'].strip()
            context_enc = context.encode('utf-8')

            context_tokens, context_ids = tokenizer_fn(context)
            for qa in paragraph['qas']:
                question = qa['question'].strip()
                id_ = qa['id']

                answers = [answer['text'].strip() for answer in qa['answers']]
                answer_starts = [
                    answer['answer_start'] for answer in qa['answers']
                ]
                answer_ends = [
                    start + len(answer)
                    for start, answer in zip(answer_starts, answers)
                ]

                feats = {}
                feats['id'] = id_
                feats['answers'] = utf_encode_list(answers)
                feats['num_answers'] = len(answers)

                feats['context'] = context_enc
                feats['context_tokens'] = context_tokens
                if context_ids:
                    feats['context_ids'] = context_ids
                feats['context_length'] = len(context_tokens)

                question_tokens, question_ids = tokenizer_fn(question)
                feats['question'] = question.encode('utf-8')
                feats['question_tokens'] = utf_encode_list(question_tokens)
                if question_ids:
                    feats['question_ids'] = question_ids
                feats['question_length'] = len(feats['question_tokens'])

                starts = []
                ends = []
                if is_subword:
                    for answer_start, answer in zip(answer_starts, answers):
                        # Subword tokens: recover token indices directly
                        # from the answer text and its character offset.
                        start, end = get_answer_index(
                            context=context,
                            context_tokens=feats['context_tokens'],
                            answer_start=answer_start,
                            answer=answer)
                        starts.append(start)
                        ends.append(end)
                else:
                    # Word tokens: align tokens to character spans, then
                    # map the answer's character offsets to token indices.
                    spans = tokenizer_util.convert_to_spans(
                        context, feats['context_tokens'])
                    for answer_start, answer_end in zip(
                            answer_starts, answer_ends):
                        start, end = get_span(spans, answer_start, answer_end)
                        starts.append(start)
                        ends.append(end)

                feats['answers_start_token'] = starts
                feats['answers_end_token'] = ends
                feats['context_tokens'] = utf_encode_list(
                    feats['context_tokens'])
                examples.append(feats)

    if sort_by_length:
        examples = sorted(examples, key=lambda x: len(x['context_tokens']))
    for example in examples:
        yield example
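Since squad_generator is a plain Python generator, it can be consumed directly. A hedged usage sketch follows; the file path is a placeholder, and the inspected keys are the ones populated in the feature dict above:

# Placeholder path; any SQuAD v1.1-style JSON file works here.
gen = squad_generator('/path/to/train-v1.1.json')
for i, example in enumerate(gen):
    print(example['id'], example['context_length'],
          example['answers_start_token'], example['answers_end_token'])
    if i == 2:  # peek at the first few examples only
        break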
Example #3
def testConvertToSpans(self):
    raw_text = "Convert to spans."
    tokens = ["Convert", "to", "spans", "."]
    output = tokenizer_util.convert_to_spans(raw_text, tokens)
    expected_output = [(0, 7), (8, 10), (11, 16), (16, 17)]
    self.assertEqual(expected_output, output)
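Because each span is a (start, end) character range into raw_text, the token surface forms can be sliced straight back out, which makes a quick sanity check for an alignment. Built from the expected output above:

raw_text = "Convert to spans."
spans = [(0, 7), (8, 10), (11, 16), (16, 17)]
assert [raw_text[s:e] for s, e in spans] == ["Convert", "to", "spans", "."]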