def GetEdgeFromCoocc():
    """Mine tag co-occurrence edges and write them to gen_rels/edges_coocc.txt.

    Each line of training/all_data.txt is parsed as JSON and must carry 'id'
    and 'text'. Tags come from GetTags(text). A tag is kept when it appears in
    more than 2 segments, has idf > 2, is longer than one character and is not
    all digits. Kept tags are paired within the current segment and with the
    tags of the previous two segments, stopping at the first neighbour whose
    id prefix (text before '@') differs. For every pair seen at least 100
    times, the directed score pair_count / r1cnt[w] is computed (r1cnt[w] is
    the number of counted pairs involving w). Scores >= 0.2 are written as
    ['coocc', w1, w2, score] lines.

    Side effects: rebinds the module globals datalist, datadict, idf, tags and
    r1cnt.

    NOTE(review): the early ``break``s below assume ljqpy.FreqDict2List returns
    (key, value) pairs sorted by value in descending order — confirm.
    """
    global datalist, datadict, idf, tags, r1cnt
    datalist = []
    datadict = {}
    df = defaultdict(int)
    for line in ljqpy.LoadList('training/all_data.txt'):
        jj = json.loads(line)
        datadict[jj['id']] = jj['text']
        tf = GetTags(jj['text'])
        for t in tf.keys():
            df[t] += 1
        jj['tf'] = tf
        # Keep the segment: N and the co-occurrence pass below depend on it.
        datalist.append(jj)
    N = len(datalist)
    idf = {x: math.log(N / s) for x, s in df.items()}
    tags = {x for x, s in df.items() if s > 2 and idf[x] > 2 and len(x) > 1}
    tags = {x for x in tags if not x.isdigit()}

    print('docu segs:', N)
    print('tags:', len(tags))

    r2cnt = defaultdict(int)
    r1cnt = defaultdict(int)
    history = []  # (id, kept tags) of recent segments, most recent last

    for i, jj in enumerate(datalist):
        seg_id, words = jj['id'], jj['tf']
        tt = [x for x in words.keys() if x in tags]
        if i % 1000 == 0: print('datalist %d/%d' % (i, len(datalist)))

        # mi == 0 pairs tags inside the current segment; mi == 1 / 2 pair them
        # with the previous / second-previous segment.
        for mi in range(3):
            if mi > len(history): continue
            lid, lasttt = (seg_id, tt) if mi == 0 else history[-mi]
            # Only pair segments that belong to the same document.
            if lid.split('@')[0] != seg_id.split('@')[0]: break

            for w1 in tt:
                for w2 in lasttt:
                    if w1 in w2 or w2 in w1: continue
                    # Order the pair without rebinding the loop variable w1.
                    a, b = (w2, w1) if w2 < w1 else (w1, w2)
                    r2cnt[(a, b)] += 1
                    r1cnt[a] += 1
                    r1cnt[b] += 1

        history.append((seg_id, tt))
        if len(history) > 10: history = history[5:]

    relscs = {}
    for g2, ng2 in ljqpy.FreqDict2List(r2cnt):
        # Check the threshold before recording so no pair below 100 slips in.
        if ng2 < 100: break
        w1, w2 = g2
        relscs[(w1, w2)] = ng2 / r1cnt[w1]
        relscs[(w2, w1)] = ng2 / r1cnt[w2]

    with open('gen_rels/edges_coocc.txt', 'w', encoding='utf-8') as fout:
        for g2, rel in ljqpy.FreqDict2List(relscs):
            if rel < 0.2: break
            ljqpy.WriteLine(fout, ['coocc', g2[0], g2[1], rel])
def MakeS2SDict(fn=None, min_freq=5, delimiter=' ', dict_file=None):
    """Build token lists for the input and output sides of a seq2seq corpus.

    :param fn: path read with ljqpy.LoadCSV when the lists must be built; the
        first two columns of each row are the input and output sequence.
    :param min_freq: tokens counted fewer than this many times are dropped.
    :param delimiter: string used to split a sequence into tokens.
    :param dict_file: optional cache file. If it exists, the lists are loaded
        from it and ``fn`` is not read. Otherwise, if given, the built lists are
        saved to it as input tokens, a '<@@@>' separator line, then output
        tokens.
    :return: ``(itokens, otokens)`` built with TokenList.
    """
    # Reuse a previously saved dictionary when available.
    if dict_file is not None and os.path.exists(dict_file):
        print('loading', dict_file)
        lst = ljqpy.LoadList(dict_file)
        midpos = lst.index('<@@@>')
        itokens = TokenList(lst[:midpos])
        otokens = TokenList(lst[midpos + 1:])
        return itokens, otokens
    data = ljqpy.LoadCSV(fn)
    wdicts = [{}, {}]  # [input token counts, output token counts]
    for ss in data:
        # zip stops after two columns, pairing column 0/1 with wdicts[0]/[1].
        for seq, wd in zip(ss, wdicts):
            for w in seq.split(delimiter):
                wd[w] = wd.get(w, 0) + 1
    wlists = []
    for wd in wdicts:
        wd = ljqpy.FreqDict2List(wd)
        wlist = [x for x, y in wd if y >= min_freq]
        # Store the filtered list; wlists[0] and wlists[1] are read below.
        wlists.append(wlist)
    print('seq 1 words:', len(wlists[0]))
    print('seq 2 words:', len(wlists[1]))
    itokens = TokenList(wlists[0])
    otokens = TokenList(wlists[1])
    if dict_file is not None:
        ljqpy.SaveList(wlists[0] + ['<@@@>'] + wlists[1], dict_file)
    return itokens, otokens
def MakeVocab():
    """Load or build the word and character vocabularies.

    When both data/wordlist.txt and data/charlist.txt exist, they are loaded
    with ljqpy.LoadCSV. Otherwise, token and character frequencies are counted
    over trainFile and saved to those two files.

    During counting, each line of trainFile is lower-cased and parsed as JSON
    with a "query" and a list of "passages" (each with "passage_text").
    Question tokens are weighted more heavily: each character adds 10, and each
    word (after ChangeToken) adds the number of passages. Passage text is also
    normalised with FullToHalf and counts 1 per occurrence.

    Side effects: rebinds the module globals id2w, w2id, id2c and c2id. The id
    lists start with '<PAD>', '<UNK>', followed by the first vocab_size words
    and the first char_size characters, respectively.
    """
    global id2w, w2id, id2c, c2id
    vocabFile = 'data/wordlist.txt'
    charFile = 'data/charlist.txt'
    # Both caches must exist; otherwise rebuild them together.
    if os.path.exists(vocabFile) and os.path.exists(charFile):
        freqw = ljqpy.LoadCSV(vocabFile)
        freqc = ljqpy.LoadCSV(charFile)
    else:
        freqw = {}
        freqc = {}
        for line in ljqpy.LoadCSVg(trainFile):
            line = ''.join(line)
            thisJson = json.loads(line.strip().lower())
            question = thisJson["query"]
            question = re.sub(r'\s+', ' ', question.strip())
            questionTokens = CutSentence(question)
            for t in questionTokens:
                for c in t:
                    freqc[c] = freqc.get(c, 0) + 10
                t = ChangeToken(t)
                freqw[t] = freqw.get(t, 0) + len(thisJson["passages"])
            for passage in thisJson["passages"]:
                context = passage["passage_text"]
                context = FullToHalf(context)
                context = re.sub(r'\s+', ' ', context.strip())
                contextTokens = CutSentence(context)
                for t in contextTokens:
                    for c in t:
                        freqc[c] = freqc.get(c, 0) + 1
                    t = ChangeToken(t)
                    freqw[t] = freqw.get(t, 0) + 1
        freqw = ljqpy.FreqDict2List(freqw)
        ljqpy.SaveCSV(freqw, vocabFile)
        freqc = ljqpy.FreqDict2List(freqc)
        ljqpy.SaveCSV(freqc, charFile)
    id2w = ['<PAD>', '<UNK>'] + [x[0] for x in freqw[:vocab_size]]
    w2id = {y: x for x, y in enumerate(id2w)}
    id2c = ['<PAD>', '<UNK>'] + [x[0] for x in freqc[:char_size]]
    c2id = {y: x for x, y in enumerate(id2c)}
def MakeS2SDict(fn=None, min_freq=5, delimiter=' ', dict_file=None):
    """Build the word or char lists for the input and output sequences.

    :param fn: path read with ljqpy.LoadCSV when the lists must be built; the
        first two columns of each row are the input and output sequence.
    :param min_freq: tokens counted fewer than this many times are dropped.
    :param delimiter: string used to split a sequence into tokens.
    :param dict_file: optional cache file. If it exists, the lists are loaded
        from it and ``fn`` is not read. Otherwise, if given, the built lists are
        saved to it as input tokens, a '<@@@>' separator line, then output
        tokens.
    :return: ``(itokens, otokens)`` built with TokenList.
    """
    # If a saved word/char list exists, there is no need to rebuild it.
    if dict_file is not None and os.path.exists(dict_file):
        print('loading', dict_file)
        lst = ljqpy.LoadList(dict_file)
        midpos = lst.index('<@@@>')
        itokens = TokenList(lst[:midpos])
        otokens = TokenList(lst[midpos + 1:])
        return itokens, otokens
    # Otherwise build it from the corpus.
    data = ljqpy.LoadCSV(fn)
    wdicts = [{}, {}]  # [input token counts, output token counts]
    for ss in data:
        # zip stops after two columns, pairing column 0/1 with wdicts[0]/[1].
        for seq, wd in zip(ss, wdicts):
            for w in seq.split(delimiter):
                wd[w] = wd.get(w, 0) + 1
    wlists = []
    for wd in wdicts:
        wd = ljqpy.FreqDict2List(wd)
        wlist = [x for x, y in wd if y >= min_freq]
        # Store the filtered list; wlists[0] and wlists[1] are read below.
        wlists.append(wlist)
    print('seq 1 words:', len(wlists[0]))
    print('seq 2 words:', len(wlists[1]))
    itokens = TokenList(wlists[0])
    otokens = TokenList(wlists[1])
    if dict_file is not None:
        ljqpy.SaveList(wlists[0] + ['<@@@>'] + wlists[1], dict_file)
    return itokens, otokens
    return datax, datay

datadir = '../dataset/chsner_char-level'
# One LoadCoNLLFormat result per split, ordered (train, test). Index [0] is
# presumably the token sequences and [1] the label sequences — see how
# convert_data and the loop below use them.
xys = [
    LoadCoNLLFormat(os.path.join(datadir, '%s.txt' % tp))
    for tp in ['train', 'test']
]

# Label vocabulary counted over the training labels only; ids follow the
# order returned by ljqpy.FreqDict2List.
id2y = {}
for yy in xys[0][1]:
    for y in yy:
        id2y[y] = id2y.get(y, 0) + 1
id2y = [x[0] for x in ljqpy.FreqDict2List(id2y)]
y2id = {v: k for k, v in enumerate(id2y)}

def convert_data(df):
    """Truncate one data split and map its labels to ids.

    ``df[0]`` holds token sequences and ``df[1]`` the matching label
    sequences. Returns ``(text, label)``:
      * each text is its first ``max_seq_len`` tokens joined by single spaces;
      * each label list is a leading 0 followed by the ids (via ``y2id``,
        unknown labels -> 0) of the first ``max_seq_len - 1`` labels.
    NOTE(review): the leading 0 presumably aligns with a [CLS]-style token —
    confirm against the model input.
    """
    text = []
    for tokens in df[0]:
        text.append(' '.join(tokens[:max_seq_len]))

    label = []
    for tag_seq in df[1]:
        ids = [y2id.get(tag, 0) for tag in tag_seq[:max_seq_len - 1]]
        label.append([0] + ids)

    return text, label

# xys is ordered (train, test); convert each split and unpack the results.
(train_text, train_label), (test_text, test_label) = [
    convert_data(split) for split in xys
]

bert_tl = bt.ALBERTLayer(lang='cn')

    def gen_new_tags(self, corpusfn, numlim=1000):
        """Discover candidate new phrases (word bigrams) in a raw corpus.

        Reads corpusfn with ljqpy.LoadCSVg and uses the first column of each
        row as a line of text. Lines shorter than 10 characters, or with no
        run of 2+ Latin/CJK characters, are skipped. Each line is segmented
        with jieba and wrapped in '^'/'$' boundary markers. Uni-, bi- and
        trigram counts are then collected, together with the left/right
        context words of every bigram.

        A bigram (a, b) is scored only when:
          * the more frequent of its two words occurs more than 3 times;
          * its PMI is at least 2;
          * a + b passes ljqpy.IsChsStr.

        The score is pmi - min(Hlr, Hrl) + min(Hl, Hr), multiplied by the
        bigram count. Here Hl / Hr are summed -p*log(p) terms over the
        bigram's left / right context words.

        :param corpusfn: corpus path passed to ljqpy.LoadCSVg.
        :param numlim: maximum number of phrases to return.
        :return: list of phrase strings (the two words concatenated), taken
            from the top ``numlim`` scores. Also stored in ``self.newtags``
            and indexed in ``self.newtagtrie``.

        Side effects: rebinds the module globals ng1, ng2, ng3, pg1, pg2, pg3,
        pdict, ndict and scores; prints progress.
        NOTE(review): assumes ljqpy.FreqDict2List returns (key, value) pairs
        sorted by value descending — confirm.
        """
        global ng1, ng2, ng3, pg1, pg2, pg3, pdict, ndict, scores

        def _HH(p):
            return -p * math.log(p) if p > 0 else 0

        def _HY(g3, g2):
            return _HH(ng3[g3] / ng2[g2])

        def _log_probs(counts):
            # Log relative frequency per key; empty (instead of log(0)) when
            # the corpus produced no counts.
            total = sum(counts.values())
            if total == 0:
                return {}
            log_total = math.log(total)
            return {k: math.log(v) - log_total for k, v in counts.items()}

        ng1 = defaultdict(int)
        ng2 = defaultdict(int)
        ng3 = defaultdict(int)
        pdict, ndict = {}, {}  # bigram -> set of left / right context words
        for ii, lines in enumerate(ljqpy.LoadCSVg(corpusfn)):
            line = lines[0]
            if ii % 100000 == 0: print('counting', ii)
            if len(line) < 10: continue
            if re.search('[a-zA-Z\u4e00-\u9fa5]{2,}', line) is None: continue
            lln = ['^'] + jieba.lcut(line) + ['$']
            for i, wd in enumerate(lln):
                ng1[wd] += 1
                if i > 0: ng2[tuple(lln[i - 1:i + 1])] += 1
                if i > 1:
                    ng3[tuple(lln[i - 2:i + 1])] += 1
                    pdict.setdefault(tuple(lln[i - 1:i + 1]),
                                     set()).add(lln[i - 2])
                    ndict.setdefault(tuple(lln[i - 2:i]), set()).add(lln[i])
        pg1 = _log_probs(ng1)
        pg2 = _log_probs(ng2)
        pg3 = _log_probs(ng3)
        print('COUNT ok')

        scores = {}
        for ii, (k, v) in enumerate(ljqpy.FreqDict2List(pg2), 1):
            if ii % 10000 == 0: print('%d/%d' % (ii, len(pg2)))
            if max(ng1[k[0]], ng1[k[1]]) <= 3: continue
            pmi = v - pg1[k[0]] - pg1[k[1]]
            if pmi < 2: continue
            # Cheap filter first, so entropies are only summed for kept pairs.
            if not ljqpy.IsChsStr(k[0] + k[1]): continue
            Hl, Hr = 0, 0
            Hlr, Hrl = 0, 0
            for ll in pdict.get(k, []):
                Hl += _HY((ll, k[0], k[1]), k)
                Hlr += _HY((ll, k[0], k[1]), (ll, k[0]))
            for rr in ndict.get(k, []):
                Hr += _HY((k[0], k[1], rr), k)
                Hrl += _HY((k[0], k[1], rr), (k[1], rr))
            score = pmi - min(Hlr, Hrl) + min(Hl, Hr)
            scores[k] = score * ng2[k]

        phrases = []
        for k, v in ljqpy.FreqDict2List(scores)[:numlim]:
            print(k, v)
            # Collect the phrase; previously the list was never filled, so
            # newtags / newtagtrie were always empty.
            phrases.append(k[0] + k[1])
        self.newtags = phrases
        self.newtagtrie = Trie({x: 1 for x in self.newtags})
        return phrases