def searchTable3(sample=1, enhance=1):
    """Append "dzpk" text-classification samples built from HTML table cells.

    NOTE(review): a second ``def searchTable3`` appears later in this module and
    rebinds the name, so this version cannot be reached as ``searchTable3`` once
    the module is imported. Consider giving it a distinct name.

    For every document returned by ``getDingZengUnion(dz_trainpath)``, every
    table in ``dz_htmlpath + id + '.html'`` is converted by ``table2array``.
    A cell is only considered when its column header (first row) or its row
    header (first column) satisfies ``matchDuixiang``. Such a cell is labelled
    ``__label__dzpk`` if ``hasPK`` matches it against any record's ``duixiang``;
    otherwise it is labelled ``__label__nothing``. The cell text, prefixed with
    its header cell(s), is tokenised with ``jieba_tokenize`` and appended as one
    line to ``dz_pk_cls_table.dev`` (when ``int(id) % 6 == 0``) or to
    ``dz_pk_cls_table.train``.

    Args:
        sample: number of copies written for each positive (dzpk) line, used to
            rebalance positive vs. negative samples.
        enhance: global repetition factor applied to every line.
    """
    rank = 6  # documents with int(id) % rank == 0 go to the dev split
    with open('dz_pk_cls_table.train', 'a+') as dz_train, \
            open('dz_pk_cls_table.dev', 'a+') as dz_dev:
        dingzengs = getDingZengUnion(dz_trainpath)
        for doc_id, dzs in dingzengs.items():
            makefile = dz_dev if int(doc_id) % rank == 0 else dz_train

            htmlpath = dz_htmlpath + doc_id + '.html'
            with open(htmlpath) as html:
                soup = BeautifulSoup(html, 'lxml')

            for table in soup.find_all('table'):  # every table in the page
                # table2array turns the table into regular 2-D row/column arrays
                for cut in table2array(table):
                    if len(cut) == 0:
                        continue  # nothing to label; cut[0] would raise
                    header_row = cut[0]
                    cols = len(header_row)
                    for row, cells in enumerate(cut):
                        leftcell = cells[0]
                        for col in range(cols):
                            valuecell = cells[col]
                            topcell = header_row[col]

                            if not (matchDuixiang(topcell)
                                    or matchDuixiang(leftcell)):
                                continue

                            positive = any(
                                hasPK(valuecell, dz.duixiang, 'DX', dz)
                                for dz in dzs)
                            label = ('__label__dzpk ' if positive
                                     else '__label__nothing ')

                            # Prefix the cell with its header cell(s) so the
                            # classifier sees the row/column context.
                            if row and col:
                                text = topcell + leftcell + valuecell
                            elif row:
                                text = topcell + valuecell
                            elif col:
                                text = leftcell + valuecell
                            else:
                                text = valuecell

                            toline = (label
                                      + ' '.join(jieba_tokenize(text)) + '\n')

                            # Positive lines are repeated `sample` times to
                            # shift the positive/negative ratio; every line is
                            # repeated `enhance` times. Non-positive counts
                            # mean "write nothing", as with range().
                            per_round = max(sample, 0) if positive else 1
                            makefile.write(toline * (max(enhance, 0) * per_round))

            for dz in dzs:
                dz.desc()
def maketrain_word_debug(id, before=0):
    """Run the contract-field tagging of ``maketrain_word`` for one document.

    Each sentence of ``dz_htmlpath + id + '.html'`` (as returned by
    ``levelText_without_table``) is prefixed with up to ``before`` preceding
    sentences. It is tokenised with ``jieba_tokenize`` and passed through
    ``mask_contract_field_word`` for every contract of document ``id``.
    Nothing is written to disk, and the tags and labels built here are
    discarded. The visible output comes from ``desc()`` on each contract at the
    end, plus any side effects of ``mask_contract_field_word``.

    Args:
        id: document id; a key of ``getHeTongUnion(dz_trainpath)``.
        before: number of preceding sentences to prepend as context.
    """
    contracts = getHeTongUnion(dz_trainpath)[id]
    sentences = levelText_without_table(dz_htmlpath + id + '.html')

    # (attribute, tag, last mask_contract_field_word flag, label), checked in
    # exactly this order because every call may update the shared tag list.
    field_specs = (
        ('jiafang', 'JF', False, '__label__jf'),
        ('yifang', 'YF', False, '__label__yf'),
        ('xiangmu', 'XM', False, '__label__xm'),
        ('hetong', 'HT', False, '__label__ht'),
        ('amount_u', 'AU', True, '__label__amount'),
        ('amount_d', 'AD', True, '__label__amount'),
    )

    for position, sentence in enumerate(sentences):
        window = sentences[max(0, position - before):position]
        context = ''.join(window) + sentence
        tokens = jieba_tokenize(context)
        found_labels = set()
        tags = ['O'] * len(tokens)
        for contract in contracts:
            for attr_name, field_tag, flag, field_label in field_specs:
                if mask_contract_field_word(context, tokens,
                                            getattr(contract, attr_name),
                                            tags, field_tag, contract, flag):
                    found_labels.add(field_label)
            partners = contract.lianhe.split('、')
            for partner in partners:
                if mask_contract_field_word(context, tokens, partner, tags,
                                            'LH', contract, False):
                    found_labels.add('__label__lht')

    for contract in contracts:
        contract.desc()
def predict(model, text):
    """Classify raw ``text`` with ``model`` and return its top prediction.

    The text is tokenised with ``jieba_tokenize`` and space-joined before it is
    passed to ``model.predict``. The first element of each of the first two
    items of the result is returned. For a fastText model these are presumably
    (label, probability); confirm against the model type actually used.
    """
    tokenized = ' '.join(jieba_tokenize(text))
    result = model.predict(tokenized)
    best_label = result[0][0]
    best_score = result[1][0]
    return best_label, best_score
def maketrain_word(before=0):
    """Append word-level NER training data for contract ("hetong") fields.

    For every document id in ``getHeTongUnion(dz_trainpath)``, each sentence
    returned by ``levelText_without_table`` is prefixed with up to ``before``
    preceding sentences. It is tokenised with ``jieba_tokenize`` and tagged by
    ``mask_contract_field_word`` for every contract of that document. One line
    per token (``"<1-based index>\\t<token>\\t<tag>\\n"``), followed by a blank
    line, is appended to a file under ``dataroot`` chosen by ``int(id) % 7``:
    0 -> ``ht_all_text_word.dev``, 1 -> ``.test``, otherwise -> ``.train``.

    The sentence-classification output (``ht_cls_text_word.*``) was disabled in
    the original code. Those files were opened but never written, so they are
    no longer opened.

    Args:
        before: number of preceding sentences to prepend as context.
    """
    rank = 7
    # (contract attribute, NER tag, last mask_contract_field_word flag). The
    # order matters because every call may update the shared tag list.
    field_specs = (
        ('jiafang', 'JF', False),
        ('yifang', 'YF', False),
        ('xiangmu', 'XM', False),
        ('hetong', 'HT', False),
        ('amount_u', 'AU', True),
        ('amount_d', 'AD', True),
    )
    with open(dataroot + 'ht_all_text_word.train', 'a+') as ht_train, \
            open(dataroot + 'ht_all_text_word.dev', 'a+') as ht_dev, \
            open(dataroot + 'ht_all_text_word.test', 'a+') as ht_test:
        hetongs = getHeTongUnion(dz_trainpath)
        for doc_id, hts in hetongs.items():
            mod = int(doc_id) % rank
            if mod == 0:
                nerfile = ht_dev
            elif mod == 1:
                nerfile = ht_test
            else:
                nerfile = ht_train

            sentences = levelText_without_table(dz_htmlpath + doc_id + '.html')
            for sid, sentence in enumerate(sentences):
                context = ''.join(sentences[max(0, sid - before):sid]) + sentence
                tokens = jieba_tokenize(context)
                tags = ['O'] * len(tokens)
                for ht in hts:
                    for attr_name, field_tag, flag in field_specs:
                        mask_contract_field_word(context, tokens,
                                                 getattr(ht, attr_name),
                                                 tags, field_tag, ht, flag)
                    for lh in ht.lianhe.split('、'):
                        mask_contract_field_word(context, tokens, lh, tags,
                                                 'LH', ht, False)

                nerfile.writelines(
                    '{}\t{}\t{}\n'.format(index, token, token_tag)
                    for index, (token, token_tag)
                    in enumerate(zip(tokens, tags), 1))
                nerfile.write('\n')  # blank line separates sentences

            for ht in hts:
                ht.desc()
def searchTable3(sample=1, enhance=1):
    """Append multi-label attribute classification samples from HTML tables.

    This definition rebinds ``searchTable3``, replacing an earlier function of
    the same name in this module.

    For every document in ``getDingZengUnion(dz_trainpath)``, every cell of
    every table in ``dz_htmlpath + id + '.html'`` (converted by
    ``table2array``) is labelled. ``__label__dzdx`` / ``__label__dzsl`` /
    ``__label__dzje`` is added when ``hasAtt`` finds a record's ``duixiang`` /
    ``shuliang`` / ``jine`` in the cell but NOT in the concatenated header text
    (column header + row header). A cell with no label gets
    ``__label__nothing``. The cell text, prefixed with its header cell(s), is
    tokenised with ``jieba_tokenize`` and appended as one line to
    ``dz_att_cls_table.dev`` (when ``int(id) % 6 == 0``) or to
    ``dz_att_cls_table.train``. Labels are written in sorted order so output
    is deterministic.

    Args:
        sample: unused; kept for signature compatibility.
        enhance: unused; kept for signature compatibility.
    """
    rank = 6  # documents with int(id) % rank == 0 go to the dev split
    with open('dz_att_cls_table.train', 'a+') as dz_train, \
            open('dz_att_cls_table.dev', 'a+') as dz_dev:
        dingzengs = getDingZengUnion(dz_trainpath)
        for doc_id, dzs in dingzengs.items():
            makefile = dz_dev if int(doc_id) % rank == 0 else dz_train

            htmlpath = dz_htmlpath + doc_id + '.html'
            with open(htmlpath) as html:
                soup = BeautifulSoup(html, 'lxml')

            for table in soup.find_all('table'):  # every table in the page
                # table2array turns the table into regular 2-D row/column arrays
                for cut in table2array(table):
                    if len(cut) == 0:
                        continue  # nothing to label; cut[0] would raise
                    header_row = cut[0]
                    cols = len(header_row)
                    for row, cells in enumerate(cut):
                        leftcell = cells[0]
                        for col in range(cols):
                            valuecell = cells[col]
                            topcell = header_row[col]
                            header = topcell + leftcell

                            # An attribute only counts when it appears in the
                            # cell value itself and not in the header text.
                            labels = set()
                            for dz in dzs:
                                if hasAtt(valuecell, dz.duixiang, 'DX', dz, False, None, None, False) \
                                        and not hasAtt(header, dz.duixiang, 'DX', dz, False, None, None, False):
                                    labels.add('__label__dzdx')
                                if hasAtt(valuecell, dz.shuliang, 'SL', dz, True, topcell, leftcell, False) \
                                        and not hasAtt(header, dz.shuliang, 'SL', dz, True, topcell, leftcell, False):
                                    labels.add('__label__dzsl')
                                # Fixed: the header exclusion previously tested
                                # shuliang/'SL' instead of jine/'JE'.
                                if hasAtt(valuecell, dz.jine, 'JE', dz, True, topcell, leftcell, False) \
                                        and not hasAtt(header, dz.jine, 'JE', dz, True, topcell, leftcell, False):
                                    labels.add('__label__dzje')

                            if not labels:
                                labels.add('__label__nothing')

                            # Prefix the cell with its header cell(s) so the
                            # classifier sees the row/column context.
                            if row and col:
                                text = topcell + leftcell + valuecell
                            elif row:
                                text = topcell + valuecell
                            elif col:
                                text = leftcell + valuecell
                            else:
                                text = valuecell

                            toline = (' '.join(sorted(labels)) + ' '
                                      + ' '.join(jieba_tokenize(text)) + '\n')
                            makefile.write(toline)

            for dz in dzs:
                dz.desc()