Example #1
    def find_span_with_gt(self, context, offsets, ground_truth):
        best_f1 = 0.0
        best_span = (len(offsets) - 1, len(offsets) - 1)
        gt = self.normalize_answer(pre_proc(ground_truth)).split()

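        # Candidate boundary tokens: positions whose lower-cased surface form occurs among the ground-truth tokens.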
        ls = [
            i for i in range(len(offsets))
            if context[offsets[i][0]:offsets[i][1]].lower() in gt
        ]

        for i in range(len(ls)):
            for j in range(i, len(ls)):
                pred = self.normalize_answer(
                    pre_proc(
                        context[offsets[ls[i]][0]:offsets[ls[j]][1]])).split()
                common = Counter(pred) & Counter(gt)
                num_same = sum(common.values())
                if num_same > 0:
                    precision = 1.0 * num_same / len(pred)
                    recall = 1.0 * num_same / len(gt)
                    f1 = (2 * precision * recall) / (precision + recall)
                    if f1 > best_f1:
                        best_f1 = f1
                        best_span = (ls[i], ls[j])
        return best_span
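
For reference, `find_span_with_gt` above picks the candidate span that maximizes token-level F1 against the normalized ground truth. A minimal standalone sketch of that score (the helper name `token_f1` is ours, not part of the original class):

    from collections import Counter

    def token_f1(pred_tokens, gt_tokens):
        # Bag-of-words overlap, as in the loop above.
        common = Counter(pred_tokens) & Counter(gt_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)
        return 2 * precision * recall / (precision + recall)

    # Example: token_f1(['the', 'red', 'car'], ['red', 'car']) == 0.8
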
Example #2
    def preprocess(self, dataset_label):
        file_name = self.train_file if dataset_label == 'train' else (self.dev_file if dataset_label == 'dev' else self.test_file)
        output_file_name = os.path.join(self.spacyDir, self.data_prefix + dataset_label + '-preprocessed.json')

        print('Preprocessing', dataset_label, 'file:', file_name)
        print('Loading json...')
        with open(file_name, 'r') as f:
            dataset = json.load(f)

        print('Processing json...')

        data = []
        tot = len(dataset['data'])
        for data_idx in tqdm(range(tot)):
            datum = dataset['data'][data_idx]
            context_str = datum['story']
            _datum = {'context': context_str,
                      'source': datum['source'],
                      'id': datum['id'],
                      'filename': datum['filename']}

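            # Annotate the passage once; the per-token character offsets are needed later to map answer character spans to token spans.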
            nlp_context = nlp(pre_proc(context_str))
            _datum['annotated_context'] = self.process(nlp_context)
            _datum['raw_context_offsets'] = self.get_raw_context_offsets(_datum['annotated_context']['word'], context_str)
            _datum['qas'] = []
            assert len(datum['questions']) == len(datum['answers'])

            additional_answers = {}
            if 'additional_answers' in datum:
                for k, answer in datum['additional_answers'].items():
                    if len(answer) == len(datum['answers']):
                        for ex in answer:
                            idx = ex['turn_id']
                            if idx not in additional_answers:
                                additional_answers[idx] = []
                            additional_answers[idx].append(ex['input_text'])  # additional answers are only used for evaluation, so the raw text is fine

            for i in range(len(datum['questions'])):
                question, answer = datum['questions'][i], datum['answers'][i]
                assert question['turn_id'] == answer['turn_id']

                idx = question['turn_id']
                _qas = {'turn_id': idx,
                        'question': question['input_text'],
                        'answer': answer['input_text']}
                if idx in additional_answers:
                    _qas['additional_answers'] = additional_answers[idx]

                _qas['annotated_question'] = self.process(nlp(pre_proc(question['input_text'])))
                _qas['annotated_answer'] = self.process(nlp(pre_proc(answer['input_text'])))
                _qas['raw_answer'] = answer['input_text']
                _qas['answer_span_start'] = answer['span_start']
                _qas['answer_span_end'] = answer['span_end']

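                # Trim leading/trailing whitespace from the gold rationale and keep the character offsets in sync.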
                start = answer['span_start']
                end = answer['span_end']
                chosen_text = _datum['context'][start: end].lower()
                while len(chosen_text) > 0 and chosen_text[0] in string.whitespace:
                    chosen_text = chosen_text[1:]
                    start += 1
                while len(chosen_text) > 0 and chosen_text[-1] in string.whitespace:
                    chosen_text = chosen_text[:-1]
                    end -= 1
                input_text = _qas['answer'].strip().lower()
                if input_text in chosen_text:
                    p = chosen_text.find(input_text)
                    _qas['answer_span'] = self.find_span(_datum['raw_context_offsets'],
                                                    start + p, start + p + len(input_text))
                else:
                    _qas['answer_span'] = self.find_span_with_gt(_datum['context'],
                                                            _datum['raw_context_offsets'], input_text)
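                # Build the dialog history: up to the two previous QA turns followed by the current question.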
                long_question = ''
                for j in range(i - 2, i + 1):
                    if j < 0:
                        continue
                    long_question += ' ' + datum['questions'][j]['input_text']
                    if j < i:
                        long_question += ' ' + datum['answers'][j]['input_text']

                long_question = long_question.strip()       
                nlp_long_question = nlp(long_question)
                _qas['context_features'] = feature_gen(nlp_context, nlp_long_question)
                    
                _datum['qas'].append(_qas)
            data.append(_datum)

        # build vocabulary
        if dataset_label == 'train':
            print('Build vocabulary from training data...')
            contexts = [_datum['annotated_context']['word'] for _datum in data]
            qas = [qa['annotated_question']['word'] + qa['annotated_answer']['word'] for _datum in data for qa in _datum['qas']]
            self.train_vocab = self.build_vocab(contexts, qas)
            self.train_char_vocab = self.build_char_vocab(self.train_vocab)

        print('Getting word ids...')
        w2id = {w: i for i, w in enumerate(self.train_vocab)}
        c2id = {c: i for i, c in enumerate(self.train_char_vocab)}
        for _datum in data:
            _datum['annotated_context']['wordid'] = token2id_sent(_datum['annotated_context']['word'], w2id, unk_id = 1, to_lower = False)
            _datum['annotated_context']['charid'] = char2id_sent(_datum['annotated_context']['word'], c2id, unk_id = 1, to_lower = False)
            for qa in _datum['qas']:
                qa['annotated_question']['wordid'] = token2id_sent(qa['annotated_question']['word'], w2id, unk_id = 1, to_lower = False)
                qa['annotated_question']['charid'] = char2id_sent(qa['annotated_question']['word'], c2id, unk_id = 1, to_lower = False)
                qa['annotated_answer']['wordid'] = token2id_sent(qa['annotated_answer']['word'], w2id, unk_id = 1, to_lower = False)
                qa['annotated_answer']['charid'] = char2id_sent(qa['annotated_answer']['word'], c2id, unk_id = 1, to_lower = False)

        if dataset_label == 'train':
            # get the condensed dictionary embedding
            print('Getting embedding matrix for ' + dataset_label)
            embedding = build_embedding(self.glove_file, self.train_vocab, self.glove_dim)
            meta = {'vocab': self.train_vocab, 'char_vocab': self.train_char_vocab, 'embedding': embedding.tolist()}
            meta_file_name = os.path.join(self.spacyDir, dataset_label + '_meta.msgpack')
            print('Saving meta information to', meta_file_name)
            with open(meta_file_name, 'wb') as f:
                msgpack.dump(meta, f, encoding='utf8')

        dataset['data'] = data

        if dataset_label == 'test':
            return dataset

        with open(output_file_name, 'w') as output_file:
            json.dump(dataset, output_file, sort_keys=True, indent=4)
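
The helper `find_span` used above is not part of this listing. A plausible minimal version, assuming `raw_context_offsets` is a list of `(char_start, char_end)` pairs, one per context token (the same layout `find_span_with_gt` in Example #1 relies on), could look like this; it is a sketch, not the original implementation:

    def find_span(offsets, start, end):
        # Return (first_token, last_token) indices whose character ranges
        # cover the [start, end) answer span; -1 marks "not found".
        start_index = end_index = -1
        for i, (s, e) in enumerate(offsets):
            if start_index == -1 and s <= start < e:
                start_index = i
            if s < end <= e:
                end_index = i
        return (start_index, end_index)
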
Example #3
    def preprocess(self, dataset_label):
        file_name = self.train_file if dataset_label == 'train' else (
            self.dev_file if dataset_label == 'dev' else self.test_file)
        output_file_name = os.path.join(
            self.spacyDir,
            self.data_prefix + dataset_label + '-preprocessed.json')

        print('Preprocessing', dataset_label, 'file:', file_name)
        print('Loading json...')
        with open(file_name, 'r') as f:
            dataset = json.load(f)

        print('Processing json...')

        dict1 = ['where', 'when', 'who']
        data = []
        tot = len(dataset['data'])
        type1 = type2 = 0
        for data_idx in tqdm(range(tot)):
            datum = dataset['data'][data_idx]
            context_str = datum['story']
            _datum = {
                'context': context_str,
                'source': datum['source'],
                'id': datum['id']
            }

            nlp_context = nlp(pre_proc(context_str))
            _datum['annotated_context'] = self.process(nlp_context)
            _datum['raw_context_offsets'] = self.get_raw_context_offsets(
                _datum['annotated_context']['word'], context_str)
            _datum['qas'] = []

            assert len(datum['questions']) == len(datum['answers'])

            for i in range(len(datum['questions'])):
                question, answer = datum['questions'][i], datum['answers'][i]
                assert question['turn_id'] == answer['turn_id']

                idx = question['turn_id']
                _qas = {
                    'turn_id': idx,
                    'question': question['input_text'],
                    'answer': answer['input_text']
                }

                _qas['annotated_question'] = self.process(
                    nlp(pre_proc(question['input_text'])))

                _qas['annotated_answer'] = self.process(
                    nlp(pre_proc(answer['input_text'])))
                _qas['raw_answer'] = answer['input_text']
                _qas['span_text'] = answer['span_text']

                tmp = _qas['raw_answer']
                tmp = self.removePunctuation(tmp)
                if _qas['raw_answer'] in context_str or tmp.lower() in [
                        "yes", "no", "unknown"
                ]:
                    type1 += 1
                    _qas['answer_type'] = "extractive"
                else:
                    type2 += 1
                    _qas['answer_type'] = "generative"
                _qas['answer_span_start'] = answer['span_start']
                _qas['answer_span_end'] = answer['span_end']

                sign = ""
                ques = question['input_text'].lower()
                real_ans = answer['input_text'].lower()
                real = self.remove_punctual(real_ans)
                real = real.split()

                for word in dict1:
                    if (word in ques or ques[:3] == "was"
                            or ques[:4] == 'were' or ques[:2] == 'is'):
                        sign = "factual"
                        break

                if len(real) <= 4:
                    sign = "factual"
                if not sign or real_ans == "no" or real_ans == "yes" or real_ans == 'unknown':
                    sign = "factual"

                _qas['question_type'] = sign

                start = answer['span_start']  # rationale span range
                end = answer['span_end']
                chosen_text = _datum['context'][start:end].lower()
                while len(chosen_text) > 0 and chosen_text[0] in string.whitespace:  # strip leading whitespace (\t, \n, etc.)
                    chosen_text = chosen_text[1:]
                    start += 1
                while len(chosen_text) > 0 and chosen_text[-1] in string.whitespace:  # strip trailing whitespace
                    chosen_text = chosen_text[:-1]
                    end -= 1
                input_text = _qas['answer'].strip().lower()
                if input_text in chosen_text:
                    p = chosen_text.find(input_text)  # p: start index of input_text within chosen_text
                    _qas['answer_span'] = self.find_span(
                        _datum['raw_context_offsets'], start + p,
                        start + p + len(input_text))
                else:
                    _qas['answer_span'] = self.find_span_with_gt(
                        _datum['context'], _datum['raw_context_offsets'],
                        input_text)

                _datum['qas'].append(_qas)
            data.append(_datum)

        # build vocabulary
        if dataset_label == 'train':
            print('Build vocabulary from training data...')
            contexts = [_datum['annotated_context']['word'] for _datum in data]
            qas = [
                qa['annotated_question']['word'] +
                qa['annotated_answer']['word'] for _datum in data
                for qa in _datum['qas']
            ]
            self.train_vocab = self.build_vocab(contexts, qas)

        print('Getting word ids...')
        w2id = {w: i for i, w in enumerate(self.train_vocab)}
        for _datum in data:
            _datum['annotated_context']['wordid'] = token2id_sent(
                _datum['annotated_context']['word'],
                w2id,
                unk_id=1,
                to_lower=False)
            # map each QA pair's tokens to word ids
            for qa in _datum['qas']:
                qa['annotated_question']['wordid'] = token2id_sent(
                    qa['annotated_question']['word'],
                    w2id,
                    unk_id=1,
                    to_lower=False)
                qa['annotated_answer']['wordid'] = token2id_sent(
                    qa['annotated_answer']['word'],
                    w2id,
                    unk_id=1,
                    to_lower=False)

        if dataset_label == 'train':
            # get the condensed dictionary embedding
            print('Getting embedding matrix for ' + dataset_label)
            embedding = build_embedding(self.glove_file, self.train_vocab,
                                        self.glove_dim)
            meta = {'vocab': self.train_vocab, 'embedding': embedding.tolist()}
            meta_file_name = os.path.join(self.spacyDir,
                                          dataset_label + '_meta.msgpack')
            print('Saving meta information to', meta_file_name)
            with open(meta_file_name, 'wb') as f:
                # msgpack.dump(meta, f, encoding='utf8')
                msgpack.dump(meta, f)

        dataset['data'] = data

        if dataset_label == 'test':
            return dataset

        with open(output_file_name, 'w') as output_file:
            json.dump(dataset, output_file, sort_keys=True, indent=4)
        print("The amount of extractive qa is: ", type1)
        print("The amount of generative qa is: ", type2)
Example #4
    def preprocess(self, dataset_label):
        # file_name = self.train_file if dataset_label == 'train' else (self.dev_file if dataset_label == 'dev' else self.test_file)
        file_name = os.path.join(self.opt['datadir'],
                                 self.opt[dataset_label + '_FILE'])
        output_file_name = os.path.join(
            self.spacyDir, dataset_label + '-preprocessed.msgpack')
        log.info('Preprocessing : {}\n File : {}'.format(
            dataset_label, file_name))
        print('Loading json...')
        # with open(file_name, 'r') as f:
        #     dataset = json.load(f)
        with open(file_name, 'rb') as f:
            dataset = msgpack.load(f, encoding='utf8')
            # if 'DEBUG' in self.opt:
            #     dataset['data'] = dataset['data'][:10]
        if self.BuildTestVocabulary and dataset_label == 'train':
            data_len = [len(dataset['data'])]
            data_list = []
            output_file_name_list = []
            for _dataset_label in self.dataset_labels[1:]:
                _file_name = os.path.join(self.opt['datadir'],
                                          self.opt[_dataset_label + '_FILE'])
                with open(_file_name, 'rb') as f:
                    _dataset = msgpack.load(f, encoding='utf8')
                    # if 'DEBUG' in self.opt:
                    #     _dataset['data'] = _dataset['data'][:10]
                    data_list.append(_dataset)
                    data_len.append(len(_dataset['data']))
                    dataset['data'].extend(_dataset['data'])
                    output_file_name_list.append(
                        os.path.join(self.spacyDir,
                                     _dataset_label + '-preprocessed.msgpack'))
            # print(data_len)
            # print(len(data_list))
            # print(data_st_idx)
            # print(len(dataset['data']))
            # print(self.dataset_labels)
        # else:
        #     test_id = {}

        print('Processing json...')

        data = []
        tot = len(dataset['data'])
        # tot = len(dataset['data'])
        ocr = 0
        od = 0
        yolo = 0
        quetion_str = []
        ans_str = []
        ocr_str = []
        od_str = []

        ocr_dict = {}
        od_dict = {}
        n_gram = self.n_gram

        #span for distinguishing the ocr distractors and 2-grams
        # len_ocr = []
        # len_2_gram = []

        over_range = []
        dis_pos_pad = [0 for i in range(8)]
        zero_len_ans = 0
        ocr_name_list_gram = [
            'OCR_gram2', 'TextSpotter_gram2', 'ensemble_ocr_gram2',
            'two_stage_OCR_gram2', 'OCR_corner_gram2', 'PMTD_MORAN_gram2'
        ]
        ocr_name_list = [
            'distractors', 'OCR', 'TextSpotter', 'ensemble_ocr',
            'two_stage_OCR', 'OCR_corner', 'PMTD_MORAN', 'ES_ocr', 'ES_ocr_30'
        ]
        od_name_list = ['OD', 'YOLO', 'OD_bottom-up']
        # ES_ocr_list = []
        if 'preprocess_ocr_name' in self.opt:
            ocr_name_list = self.opt['preprocess_ocr_name'].split(',')
            ocr_name_list_gram = [
                t + '_gram' + str(self.opt['n_gram']) for t in ocr_name_list
                if t != 'distractors' and 'ES_ocr' not in t
            ]
        if 'preprocess_od_name' in self.opt:
            od_name_list = self.opt['preprocess_od_name'].split(',')

        for data_idx in tqdm(range(tot)):
            datum = dataset['data'][data_idx]
            dis_ocr = []
            if 'distractors' in datum:
                for _dis in datum['distractors']:
                    if len(_dis) == 0:
                        zero_len_ans += 1
                        _dis = '#'
                    dis_ocr.append({'word': _dis, 'pos': dis_pos_pad})
                #assert len(dis_ocr) == 100
                datum['distractors'] = dis_ocr
            if 'answers' not in datum:
                datum['answers'] = []
            # if len(datum['OCR']) == 0:
            #     continue

            que_str = datum['question'].lower()
            _datum = {
                'question': datum['question'],
                'filename': datum['file_path'],
                'question_id': datum['question_id'],
            }
            # if 'ES_ocr' in ocr_name_list:
            #     if 'ES_ocr_len' in datum:
            #         _datum['ES_ocr_len'] = datum['ES_ocr_len']
            #     else:
            #         _datum['ES_ocr_len'] = len(datum['ES_ocr'])
            #         assert _datum['ES_ocr_len'] == 100

            quetion_str.append(que_str)

            ans_str.extend([item.lower() for item in datum['answers']])
            _datum['orign_answers'] = datum['answers']

            _datum['OCR'] = []
            # _datum['distractors'] = []
            assert 'image_width' in datum
            assert 'image_height' in datum
            width = datum['image_width']
            height = datum['image_height']
            # ocr_name_list = ['distractors', 'OCR', 'ensemble_ocr', 'TextSpotter']
            pos_padding = [0 for _ in range(8)]
            for _ocr_name in ocr_name_list:
                _datum[_ocr_name] = []
                if _ocr_name not in datum:
                    datum[_ocr_name] = []
                for i in range(len(datum[_ocr_name])):
                    original = datum[_ocr_name][i]['word']
                    word = datum[_ocr_name][i]['word'].lower()
                    if word not in ocr_dict:
                        ocr_dict[word] = len(ocr_str)
                        ocr_str.append(datum[_ocr_name][i]['word'].lower())
                    if 'pos' not in datum[_ocr_name][i]:
                        ocr_pos = pos_padding
                    else:
                        ocr_pos = datum[_ocr_name][i]['pos']

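                    # Normalize the 4 corner coordinates to [0, 1] by image width/height; out-of-range values are collected for logging.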
                    for j in range(4):
                        ocr_pos[2 * j] = ocr_pos[2 * j] / width
                        ocr_pos[2 * j + 1] = ocr_pos[2 * j + 1] / height
                    for j in ocr_pos:
                        if not (j <= 1 and 0 <= j):
                            over_range.append(j)
                    ocr_tmp = {
                        'word': word,
                        'pos': ocr_pos,
                        'original': original,
                        'ANLS': 0,
                        'ACC': 0
                    }
                    if 'cnt' in datum[_ocr_name][i]:
                        ocr_tmp['cnt'] = datum[_ocr_name][i]['cnt']
                    if 'ANLS' in datum[_ocr_name][i]:
                        ocr_tmp['ANLS'] = datum[_ocr_name][i]['ANLS']
                    if 'ACC' in datum[_ocr_name][i]:
                        ocr_tmp['ACC'] = datum[_ocr_name][i]['ACC']
                    _datum[_ocr_name].append(ocr_tmp)
            for _od_name in od_name_list:
                _datum[_od_name] = []
                for i in range(len(datum[_od_name])):
                    original = datum[_od_name][i]['object']
                    _od_word = datum[_od_name][i]['object'].lower()
                    if _od_word not in od_dict:
                        od_dict[_od_word] = len(od_str)
                        od_str.append(_od_word)
                    # od_str.append(datum[_od_name][i]['object'].lower())
                    _od_pos = datum[_od_name][i]['pos']
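                    # Convert the (center_x, center_y, width, height) detection box into four corner points, then normalize them like the OCR boxes.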
                    od_pos = []
                    _width = int(_od_pos[2] / 2)
                    _height = int(_od_pos[3] / 2)
                    od_pos.extend([_od_pos[0] - _width, _od_pos[1] - _height])
                    od_pos.extend([_od_pos[0] + _width, _od_pos[1] - _height])
                    od_pos.extend([_od_pos[0] + _width, _od_pos[1] + _height])
                    od_pos.extend([_od_pos[0] - _width, _od_pos[1] + _height])
                    for j in range(4):
                        od_pos[2 * j] = od_pos[2 * j] / width
                        od_pos[2 * j + 1] = od_pos[2 * j + 1] / height
                    for j in od_pos:
                        if not (j <= 1 and 0 <= j):
                            over_range.append(j)
                    _datum[_od_name].append({
                        'object': _od_word,
                        'pos': od_pos,
                        'original': original
                    })
            data.append(_datum)
        # print('\nod num: {}\t ocr num: {}'.format(od, ocr))
        # log.info()
        log.info('ZERO LENGTH ANS: {}'.format(zero_len_ans))
        log.info('length of data: {}'.format(len(data)))
        del dataset
        #thread = multiprocessing.cpu_count()
        thread = 20
        log.info('Using {} threads to tokenize'.format(thread))
        que_iter = (pre_proc(c) for c in quetion_str)
        ans_iter = (pre_proc(c) for c in ans_str)
        ocr_iter = (pre_proc(c) for c in ocr_str)
        od_iter = (pre_proc(c) for c in od_str)
        # yolo_iter = (pre_proc(c) for c in yolo_str)
        que_docs = [
            doc for doc in nlp.pipe(que_iter, batch_size=64, n_threads=thread)
        ]
        ans_docs = [
            doc for doc in nlp.pipe(ans_iter, batch_size=64, n_threads=thread)
        ]
        ocr_docs = [
            doc for doc in nlp.pipe(ocr_iter, batch_size=64, n_threads=thread)
        ]
        od_docs = [
            doc for doc in nlp.pipe(od_iter, batch_size=64, n_threads=thread)
        ]
        # yolo_docs = [doc for doc in nlp.pipe(yolo_iter, batch_size=64, n_threads=thread)]
        assert len(que_docs) == len(quetion_str)
        assert len(ans_docs) == len(ans_str)
        assert len(ocr_docs) == len(ocr_str)
        assert len(od_docs) == len(od_str)
        # assert len(yolo_docs) == len(yolo_str)
        ocr_output = [self.process(item) for item in ocr_docs]
        od_output = [self.process(item) for item in od_docs]

        que_idx = ans_idx = ocr_idx = od_idx = yolo_idx = 0
        for _datum in tqdm(data):
            _datum['annotated_question'] = self.process(que_docs[que_idx])
            _datum['raw_question_offsets'] = self.get_raw_context_offsets(
                _datum['annotated_question']['word'], quetion_str[que_idx])
            que_idx += 1
            _datum['answers'] = []
            for i in _datum['orign_answers']:
                _datum['answers'].append(self.process(ans_docs[ans_idx]))
                ans_idx += 1
            for _ocr_name in ocr_name_list:
                for i in range(len(_datum[_ocr_name])):
                    # output = self.process(ocr_docs[ocr_idx])
                    # ocr_idx += 1
                    tmp_ocr = ocr_dict[_datum[_ocr_name][i]['word']]
                    if len(ocr_output[tmp_ocr]['word']) != 1:
                        ocr += 1
                    _datum[_ocr_name][i]['word'] = ocr_output[tmp_ocr]
                    ocr_idx += 1
            for _od_name in od_name_list:
                for i in range(len(_datum[_od_name])):
                    tmp_od = od_dict[_datum[_od_name][i]['object']]
                    output = od_output[tmp_od]
                    od_idx += 1
                    if len(output['word']) != 1:
                        od += 1
                    _datum[_od_name][i]['object'] = output
        assert len(que_docs) == que_idx
        assert len(ans_docs) == ans_idx
        # assert len(ocr_docs) == ocr_idx
        # assert len(od_docs) == od_idx
        # assert len(yolo_docs) == yolo_idx
        log.info('od: {} \t ocr: {} \t yolo: {}'.format(od, ocr, yolo))

        # build vocabulary
        if dataset_label == 'train':
            print('Build vocabulary from training data...')
            contexts = [
                _datum['annotated_question']['word'] for _datum in data
            ]
            _words = []
            for _datum in data:
                for _ocr_name in ocr_name_list:
                    _words.extend(
                        [item['word']['word'] for item in _datum[_ocr_name]])
                for _od_name in od_name_list:
                    _words.extend(
                        [item['object']['word'] for item in _datum[_od_name]])
            # ocr = [item['word']['word'] for item in _datum['OCR'] for _datum in data]
            # od = [item['object']['word'] for item in _datum['OD'] for _datum in data]
            # yolo = [item['object']['word'] for item in _datum['YOLO'] for _datum in data]
            ans = [t['word'] for _datum in data for t in _datum['answers']]
            if "FastText" in self.opt:
                self.train_vocab = self.build_all_vocab(contexts + _words, ans)
            elif 'GLOVE' in self.opt:
                self.train_vocab = self.build_vocab(contexts + _words, ans)

            self.train_char_vocab = self.build_char_vocab(self.train_vocab)
            del contexts
        print('Getting word ids...')
        w2id = {w: i for i, w in enumerate(self.train_vocab)}
        c2id = {c: i for i, c in enumerate(self.train_char_vocab)}
        que_oov = 0
        od_oov = [0 for t in od_name_list]
        ocr_oov = [0 for t in ocr_name_list]
        od_token_total = [0 for t in od_name_list]
        ocr_token_total = [0 for t in ocr_name_list]
        que_token_total = 0
        ocr_m1 = ocr_m2 = 0
        for _i, _datum in enumerate(data):
            _datum['annotated_question']['wordid'], oov, l = token2id_sent(
                _datum['annotated_question']['word'],
                w2id,
                unk_id=1,
                to_lower=False)
            que_oov += oov
            que_token_total += l
            _datum['annotated_question']['charid'] = char2id_sent(
                _datum['annotated_question']['word'],
                c2id,
                unk_id=1,
                to_lower=False)
            for _ocr_name_idx, _ocr_name in enumerate(ocr_name_list):
                for ocr_i, ocr in enumerate(_datum[_ocr_name]):
                    ocr['word']['wordid'], oov, l = token2id_sent(
                        ocr['word']['word'], w2id, unk_id=1, to_lower=False)
                    ocr_oov[_ocr_name_idx] += oov
                    ocr_token_total[_ocr_name_idx] += l
                    ocr['word']['charid'] = char2id_sent(ocr['word']['word'],
                                                         c2id,
                                                         unk_id=1,
                                                         to_lower=False)
            for _od_name_idx, _od_name in enumerate(od_name_list):
                for ocr_i, ocr in enumerate(_datum[_od_name]):
                    ocr['object']['wordid'], oov, l = token2id_sent(
                        ocr['object']['word'], w2id, unk_id=1, to_lower=False)
                    od_oov[_od_name_idx] += oov
                    od_token_total[_od_name_idx] += l
                    ocr['object']['charid'] = char2id_sent(
                        ocr['object']['word'], c2id, unk_id=1, to_lower=False)

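            # Build n-gram OCR candidates: join n consecutive OCR tokens, merge their boxes and annotations, and score each candidate with ANLS/ACC against the answers.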
            for _gram_name in ocr_name_list_gram:
                _datum[_gram_name] = []
                _ocr_name = _gram_name[:-6]
                n = int(_gram_name[-1])
                for i in range(len(_datum[_ocr_name])):
                    if i + n > len(_datum[_ocr_name]):
                        break
                    tmp = ' '.join(
                        [t['original'] for t in _datum[_ocr_name][i:i + n]])
                    tmp = tmp.lower()
                    word = {}
                    new_pos = []
                    for j in range(i, i + n):
                        if len(new_pos) == 0:
                            new_pos = deepcopy(_datum[_ocr_name][j]['pos'])
                        else:
                            for pos_idx in range(len(new_pos)):
                                if pos_idx == 0 or pos_idx == 1 or pos_idx == 3 or pos_idx == 4:
                                    new_pos[pos_idx] = min(
                                        new_pos[pos_idx],
                                        _datum[_ocr_name][j]['pos'][pos_idx])
                                else:
                                    new_pos[pos_idx] = max(
                                        new_pos[pos_idx],
                                        _datum[_ocr_name][j]['pos'][pos_idx])
                        for k, v in _datum[_ocr_name][j]['word'].items():
                            if k not in word:
                                word[k] = deepcopy(v)
                            else:
                                word[k] += deepcopy(v)
                    if len(_datum['orign_answers']) == 0:
                        _acc = _anls = 0
                    else:
                        _acc = eval_func.note_textvqa(_datum['orign_answers'],
                                                      tmp)
                        _anls = eval_func.note_stvqa(_datum['orign_answers'],
                                                     tmp)
                    _datum[_gram_name].append({
                        'word': word,
                        'pos': new_pos,
                        'original': tmp,
                        'ANLS': _anls,
                        'ACC': _acc
                    })
                    # gram2_token_list.extend()
                for item in _datum[_gram_name]:
                    for wordid_item in item['word']['wordid']:
                        if type(wordid_item) is list:
                            assert False
        lines = ['|name|oov|total token|oov percentage|', '|:-:|:-:|:-:|:-:|']
        lines.append('|question oov|{}|{}|{}|'.format(
            que_oov, que_token_total, que_oov / que_token_total))
        print('question oov: {} / {} = {}'.format(que_oov, que_token_total,
                                                  que_oov / que_token_total))
        for _ocr_name_idx, _ocr_name in enumerate(ocr_name_list):
            print('{} oov: {} / {} = {}'.format(
                _ocr_name, ocr_oov[_ocr_name_idx],
                ocr_token_total[_ocr_name_idx],
                ocr_oov[_ocr_name_idx] / ocr_token_total[_ocr_name_idx]))
            lines.append('|{}|{}|{}|{}|'.format(
                _ocr_name, ocr_oov[_ocr_name_idx],
                ocr_token_total[_ocr_name_idx],
                ocr_oov[_ocr_name_idx] / ocr_token_total[_ocr_name_idx]))
        for _od_name_idx, _od_name in enumerate(od_name_list):
            print('{} oov: {} / {} = {}'.format(
                _od_name, od_oov[_od_name_idx], od_token_total[_od_name_idx],
                od_oov[_od_name_idx] / od_token_total[_od_name_idx]))
            lines.append('|{}|{}|{}|{}|'.format(
                _od_name, od_oov[_od_name_idx], od_token_total[_od_name_idx],
                od_oov[_od_name_idx] / od_token_total[_od_name_idx]))
        with open(os.path.join(self.spacyDir, 'oov.md'), 'w') as f:
            f.write('\n'.join(lines))

        dataset = {}
        if dataset_label == 'train':
            # get the condensed dictionary embedding
            print('Getting embedding matrix for ' + dataset_label)
            meta = {
                'vocab': self.train_vocab,
                'char_vocab': self.train_char_vocab
            }
            if 'FastText' in self.opt:
                fast_embedding = build_fasttext_embedding(
                    self.fasttext_model, self.train_vocab,
                    self.opt['fast_dim'])
                meta['fast_embedding'] = fast_embedding.tolist()
            if 'GLOVE' in self.opt:
                glove_embedding = build_embedding(self.glove_file,
                                                  self.train_vocab,
                                                  self.glove_dim)
                meta['glove_embedding'] = glove_embedding.tolist()
            if 'PHOC' in self.opt:
                phoc_embedding = build_phoc_embedding(self.train_vocab,
                                                      self.phoc_size)
                meta['phoc_embedding'] = phoc_embedding.tolist()
            meta_file_name = os.path.join(self.spacyDir,
                                          dataset_label + '_meta.msgpack')
            print('Saving meta information to', meta_file_name)
            with open(meta_file_name, 'wb') as f:
                msgpack.dump(meta, f, encoding='utf8')
        if self.BuildTestVocabulary and dataset_label == 'train':
            data_st_idx = data_len[0]
            for _data_idx, _data in enumerate(data_list):
                _data['data'] = data[data_st_idx:data_st_idx +
                                     data_len[_data_idx + 1]]
                data_st_idx += data_len[_data_idx + 1]
                with open(output_file_name_list[_data_idx], 'wb') as wf:
                    msgpack.dump(_data, wf, encoding='utf8')
            assert data_st_idx == len(data)
            dataset['data'] = data[:data_len[0]]
        else:
            dataset['data'] = data

        # if dataset_label == 'test':
        #     return dataset

        # with open(output_file_name, 'w') as output_file:
        #     json.dump(dataset, output_file, sort_keys=True, indent=2)
        with open(output_file_name, 'wb') as output_file:
            msgpack.dump(dataset, output_file, encoding='utf8')
        log.info('Preprocessing over')
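
For clarity, the detection-box handling in this example turns a center/size box into four corners and normalizes every coordinate by the image size. A small standalone sketch of that transform (the function name and the (cx, cy, w, h) box convention are our assumptions based on the loop above):

    def box_to_corners(cx, cy, w, h, image_width, image_height):
        # Half-extents are truncated to integers, corners are listed clockwise
        # from the top-left, and every coordinate is normalized to [0, 1].
        half_w, half_h = int(w / 2), int(h / 2)
        corners = [
            cx - half_w, cy - half_h,
            cx + half_w, cy - half_h,
            cx + half_w, cy + half_h,
            cx - half_w, cy + half_h,
        ]
        return [
            c / image_width if i % 2 == 0 else c / image_height
            for i, c in enumerate(corners)
        ]
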
Example #5
    def preprocess(self, dataset_label):
        file_name = self.train_file if dataset_label == 'train' else (
            self.dev_file if dataset_label == 'dev' else self.test_file)
        output_file_name = os.path.join(
            self.spacyDir,
            self.data_prefix + dataset_label + '-preprocessed.json')

        print('Preprocessing', dataset_label, 'file:', file_name)
        print('Loading json...')
        with open(file_name, 'r') as f:
            dataset = json.load(f)

        print('Processing json...')
        count = 0
        data = []
        tot = len(dataset['data'])
        type1 = type2 = 0
        for data_idx in tqdm(range(tot)):
            datum = dataset['data'][data_idx]['paragraphs'][0]
            context_str = datum['context']
            _datum = {
                'context': context_str,
                'title': dataset['data'][data_idx]['title'],
                'id': data_idx
            }

            nlp_context = nlp(pre_proc(context_str))
            _datum['annotated_context'] = self.process(nlp_context)
            _datum['raw_context_offsets'] = self.get_raw_context_offsets(
                _datum['annotated_context']['word'], context_str)
            _datum['qas'] = []

            # assert len(datum['qas']['questions']) == len(datum['answers'])

            for i in range(len(datum['qas'])):
                question, answer = datum['qas'][i]['question'], datum['qas'][
                    i]['answers'][0]['text']
                # assert question['turn_id'] == answer['turn_id']
                count += 1
                idx = datum['qas'][i]['id']
                _qas = {'turn_id': idx, 'question': question, 'answer': answer}

                _qas['annotated_question'] = self.process(
                    nlp(pre_proc(question)))

                _qas['annotated_answer'] = self.process(nlp(pre_proc(answer)))
                _qas['raw_answer'] = answer
                _qas['answer_type'] = "extractive"
                _qas['answer_span_start'] = datum['qas'][i]['answers'][0][
                    'answer_start']
                _qas['answer_span_end'] = _qas['answer_span_start'] + len(
                    answer) + 1
                _qas['followup'] = datum['qas'][i]['followup']
                _qas['yesno'] = datum['qas'][i]['yesno']

                tmp = _qas['raw_answer']
                tmp = self.removePunctuation(tmp)
                if _qas['raw_answer'] in context_str or tmp.lower() in [
                        "yes", "no", "unknown"
                ]:
                    type1 += 1
                    _qas['answer_type'] = "extractive"
                else:
                    type2 += 1
                    _qas['answer_type'] = "generative"

                start = _qas['answer_span_start']  # rationale span range
                end = _qas['answer_span_end']
                chosen_text = _datum['context'][start:end].lower()
                while len(chosen_text) > 0 and chosen_text[0] in string.whitespace:  # strip leading whitespace (\t, \n, etc.)
                    chosen_text = chosen_text[1:]
                    start += 1
                while len(chosen_text) > 0 and chosen_text[-1] in string.whitespace:  # strip trailing whitespace
                    chosen_text = chosen_text[:-1]
                    end -= 1
                input_text = _qas['answer'].strip().lower()
                if input_text in chosen_text:
                    p = chosen_text.find(input_text)  # p: start index of input_text within chosen_text
                    _qas['answer_span'] = self.find_span(
                        _datum['raw_context_offsets'], start + p,
                        start + p + len(input_text))
                else:
                    _qas['answer_span'] = self.find_span_with_gt(
                        _datum['context'], _datum['raw_context_offsets'],
                        input_text)

                _datum['qas'].append(_qas)

            data.append(_datum)

        # build vocabulary
        if dataset_label == 'train':
            print('Build vocabulary from training data...')
            contexts = [_datum['annotated_context']['word'] for _datum in data]
            qas = [
                qa['annotated_question']['word'] +
                qa['annotated_answer']['word'] for _datum in data
                for qa in _datum['qas']
            ]
            # self.train_vocab = self.build_vocab(contexts, qas)

        # print('Getting word ids...')
        # w2id = {w: i for i, w in enumerate(self.train_vocab)}
        # for _datum in data:
        #     _datum['annotated_context']['wordid'] = token2id_sent(_datum['annotated_context']['word'], w2id, unk_id=1,
        #                                                           to_lower=False)
        #     # new modify, get wordid
        #     for qa in _datum['qas']:
        #         qa['annotated_question']['wordid'] = token2id_sent(qa['annotated_question']['word'], w2id, unk_id=1,
        #                                                            to_lower=False)
        #         qa['annotated_answer']['wordid'] = token2id_sent(qa['annotated_answer']['word'], w2id, unk_id=1,
        #                                                          to_lower=False)

        # if dataset_label == 'train':
        #     # get the condensed dictionary embedding
        #     print('Getting embedding matrix for ' + dataset_label)
        #     embedding = build_embedding(self.glove_file, self.train_vocab, self.glove_dim)
        #     meta = {'vocab': self.train_vocab, 'embedding': embedding.tolist()}
        #     meta_file_name = os.path.join(self.spacyDir, dataset_label + '_meta.msgpack')
        #     print('Saving meta information to', meta_file_name)
        #     with open(meta_file_name, 'wb') as f:
        #         msgpack.dump(meta, f, encoding='utf8')

        dataset['data'] = data

        if dataset_label == 'test':
            return dataset

        with open(output_file_name, 'w') as output_file:
            json.dump(dataset, output_file, sort_keys=True, indent=4)
        print("The amount of extractive qa is: ", type1)
        print("The amount of generative qa is: ", type2)
        print("The amount of qas is: ", count)