# Ejemplo n.º 1
def split_doc2(data_path, out_path):
    """Convert ``document,abstract`` lines into CNN/DM-style ``.story`` files.

    Each input line read via ``load_txt_data`` is split on the first commas
    into a document (field 0) and an abstract (field 1).  Documents shorter
    than 100 characters are skipped.  Quotes and spaces are stripped, the
    document is cut into sentences on the Chinese full stop ``。``, every
    character is space-separated, and the result plus an ``@highlight``
    marker and the abstract is written to ``out_path<doc_index>.story``.

    :param data_path: path understood by ``load_txt_data``
    :param out_path: path prefix for the generated ``.story`` files
    """
    import re
    data = load_txt_data(data_path)
    # Compile once instead of re-parsing the pattern on every line.
    cleaner = re.compile(r"[\" ]")
    doc_index = 0
    for raw_line in tqdm(data):
        try:
            line = raw_line.split(',')
            if len(line[0]) < 100:
                # Document too short to be a useful training example.
                continue
            abstract = ' '.join(cleaner.sub('', line[1]))
            sentences = cleaner.sub('', line[0]).split('。')
            document = [' '.join(sentence) for sentence in sentences]
        except IndexError:
            # Malformed line without an abstract field — skip it.
            continue

        document = [sentence + '\n' for sentence in document]
        new_doc = document + ['@highlight\n'] + [abstract]
        save_txt_file(new_doc, out_path + str(doc_index) + '.story')
        doc_index += 1
# Ejemplo n.º 2
def filter_data(path):
    """Drop records whose document part is shorter than 100 characters.

    Each line of the file at *path* is expected to look like
    ``abstract,document``.  Surviving lines are written back to the SAME
    file, overwriting the original content.

    :param path: text file read via ``load_txt_data`` and rewritten via
        ``save_txt_file``
    """
    data = load_txt_data(path)
    res = []
    for item in tqdm(data, desc='Filter'):
        raw = item.split(',')
        # Skip malformed lines missing the document column instead of
        # raising IndexError (matches how revers_index handles them).
        if len(raw) < 2:
            continue
        abst = raw[0]
        doc = raw[1]
        if len(doc) >= 100:
            res.append('{},{}'.format(abst, doc))
    save_txt_file(res, path)
# Ejemplo n.º 3
def revers_index(path):
    """Swap the first two comma-separated columns of every line, in place.

    Lines that have no second column are dropped.  The file at *path* is
    rewritten with the reordered lines.

    :param path: text file read via ``load_txt_data`` and rewritten via
        ``save_txt_file``
    """
    swapped = []
    for entry in load_txt_data(path):
        fields = entry.split(',')
        if len(fields) < 2:
            # No second column — nothing to swap; drop the line.
            continue
        swapped.append('{},{}'.format(fields[1], fields[0]))
    save_txt_file(swapped, path)
# Ejemplo n.º 4
def split_doc(data_path, out_path):
    """Convert ``abstract,document`` lines into CNN/DM-style ``.story`` files.

    Field 0 of each line is the abstract and field 1 the document (the
    opposite order of ``split_doc2``).  The document is split into
    sentences on CJK punctuation, every character is space-separated, and
    the result plus an ``@highlight`` marker and the abstract is written to
    ``out_path<zero-padded index>.story``.

    :param data_path: path understood by ``load_txt_data``
    :param out_path: path prefix for the generated ``.story`` files
    """
    from pyparsing import oneOf
    data = load_txt_data(data_path)
    # Build the sentence-boundary matcher ONCE — the original rebuilt the
    # pyparsing grammar inside the loop on every iteration.
    punc = oneOf(list("。,;;!?"))
    for doc_index, raw_line in enumerate(tqdm(data, desc='split_doc')):
        line = raw_line.split(',')
        abstract = " ".join(line[0])
        document = [' '.join(x) + '\n' for x in punc.split(line[1])]
        new_doc = document + ['@highlight\n'] + [abstract]
        # zfill(9) pads to at least 9 digits, matching the original
        # while-loop padding (len <= 8 → prepend '0').
        save_txt_file(new_doc, out_path + str(doc_index).zfill(9) + '.story')
# Ejemplo n.º 5
    def predict(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Run extractive-summarization inference over *test_iter*.

        For each batch, sentence ids are selected either by the model's
        scores or by one of two baselines (LEAD: leading sentences;
        ORACLE: sentences whose label is 1).  Selected sentences, gold
        summaries, and original documents are written to
        ``.candidate`` / ``.gold`` / ``.origin`` files, and ROUGE is
        optionally reported.

        :param test_iter: iterator of batches exposing ``src``, ``labels``,
            ``segs``, ``clss``, ``mask``, ``mask_cls``, ``src_str``,
            ``tgt_str`` (and ``doc_id`` when ``args.vy_predict`` is set)
        :param step: training step used to name output files; ``-1``
            suppresses ROUGE reporting
        :param cal_lead: if True, use the LEAD baseline instead of the model
        :param cal_oracle: if True, use the ORACLE baseline instead of the model
        :return: :obj:`Statistics` with the accumulated validation loss
        """

        # Set models in validating mode.
        def _get_ngrams(n, text):
            # Return the set of all n-grams (as tuples) in the token list.
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            # Trigram blocking: True if candidate sentence `c` shares any
            # trigram with an already-selected sentence in `p`.
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        # Baselines never run the model, so eval mode is only needed when
        # real inference happens.
        if not cal_lead and not cal_oracle:
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path +
                                            self.args.data_name, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path +
                                        self.args.data_name, step)
        origin_path = '%s_step%d.origin' % (self.args.result_path +
                                            self.args.data_name, step)
        with open(can_path, 'w', encoding='utf-8') as save_pred:
            with open(gold_path, 'w', encoding='utf-8') as save_gold:
                with torch.no_grad():
                    origin = []
                    for batch in test_iter:

                        src = batch.src  # 7 sentences
                        # logger.info('origin sent: %s' % len(batch.src_str))  # 7 sentences

                        labels = batch.labels
                        segs = batch.segs
                        clss = batch.clss
                        mask = batch.mask
                        mask_cls = batch.mask_cls

                        gold = []
                        pred = []

                        if cal_lead:
                            # LEAD baseline: take every sentence id in order.
                            selected_ids = [list(range(batch.clss.size(1)))
                                            ] * batch.batch_size
                        elif cal_oracle:
                            # ORACLE baseline: take the gold-labelled ids.
                            selected_ids = [[
                                j for j in range(batch.clss.size(1))
                                if labels[i][j] == 1
                            ] for i in range(batch.batch_size)]
                        else:
                            sent_scores, mask = self.model(
                                src, segs, clss, mask, mask_cls)

                            loss = self.loss(sent_scores, labels.float())
                            loss = (loss * mask.float()).sum()
                            batch_stats = Statistics(
                                float(loss.cpu().data.numpy()), len(labels))
                            stats.update(batch_stats)

                            # NOTE(review): adding the mask appears to boost
                            # valid positions so padding ranks last — confirm
                            # mask is 1 for real sentences, 0 for padding.
                            sent_scores = sent_scores + mask.float()
                            sent_scores = sent_scores.cpu().data.numpy()
                            # Descending sort: best-scoring sentences first.
                            selected_ids = np.argsort(-sent_scores, 1)

                        # selected_ids = np.sort(selected_ids,1)

                        for i, idx in enumerate(selected_ids):
                            _pred = []
                            if len(batch.src_str[i]) == 0:
                                continue
                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                if j >= len(batch.src_str[i]):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                if self.args.block_trigram:
                                    # Skip candidates that repeat a trigram
                                    # of an already-selected sentence.
                                    if not _block_tri(candidate, _pred):
                                        _pred.append(candidate)
                                        # print(candidate)
                                else:
                                    _pred.append(candidate)

                                # Standard 3-sentence summary cutoff unless
                                # doing oracle or recall evaluation.
                                if (not cal_oracle) and (
                                        not self.args.recall_eval
                                ) and len(_pred) == 3:
                                    break
                            # exit()
                            _pred = '<q>'.join(_pred)
                            # logger.info('pred sent: %s' % (_pred))
                            if self.args.recall_eval:
                                # Truncate prediction to the gold length so
                                # recall is comparable across systems.
                                _pred = ' '.join(
                                    _pred.split()
                                    [:len(batch.tgt_str[i].split())])
                                # _src = ' '.join()
                            # logger.info('origin sent: %s' % (batch.src_str[i]))
                            # logger.info('pred sent: %s' % (_pred))
                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])
                            _origin = ' '.join(batch.src_str[i])
                            if self.args.vy_predict:
                                # Prefix each origin line with its doc id.
                                doc_id = batch.doc_id
                                _origin = str(doc_id[i]) + '\t' + _origin
                            origin.append(_origin)
                        for i in range(len(gold)):
                            save_gold.write(gold[i].strip() + '\n')
                        for i in range(len(pred)):
                            save_pred.write(pred[i].strip() + '\n')
                    save_txt_file(origin, origin_path)

        if step != -1 and self.args.report_rouge:
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
# Ejemplo n.º 6
def merge_files(path_list, merge_path):
    """Concatenate the lines of every file in *path_list*, in order, and
    write the combined result to *merge_path*.

    :param path_list: iterable of paths understood by ``load_txt_data``
    :param merge_path: destination path passed to ``save_txt_file``
    """
    merged = [line for source in path_list for line in load_txt_data(source)]
    save_txt_file(merged, merge_path)
# Ejemplo n.º 7
        return new_paragraph

    def count_info(self, data):
        """Return the number of useful (non-punctuation) characters in *data*.

        :param data: iterable of characters (typically a string)
        :return: count of characters not present in ``self.puncs``
        """
        # Punctuation carries no segmentation information, so only the
        # remaining characters are counted.
        return sum(1 for char in data if char not in self.puncs)


if __name__ == '__main__':
    # Build the segmentation dataset from the raw corpus using the project's
    # punctuation and segmentation config files.
    _a = SegmentationData('../data/segmentation_corpus/raw/',
                          '../utils/config/punctuation.dat',
                          '../utils/config/segmentation.dat')
    #
    # for item in _a.raw_data:
    #     # print([item])
    #     pass
    # print(len(_a.raw_data))
    # print(_a.max_seq)
    # Persist only the punctuation-annotated training data; the segment
    # data export below is kept disabled.
    _punc_data = _a.punc_data
    save_txt_file(_punc_data,
                  '../data/segmentation_corpus/punc_data.train.txt')
    # save_txt_file(_a.seg_data, '../data/segmentation_corpus/segment_data.train.txt', end='')