Example #1
File: tools.py  Project: eastonYi/asr-tf1
def get_dataset_ngram(text_file, n, k, savefile=None, split=5000):
    """
    Simply concatenate all sents into one will bring in noisy n-gram at end of each sent.
    Here we count ngrams for each sent and sum them up.
    """
    from nltk import FreqDist
    from tqdm import tqdm

    from utils.dataProcess import get_N_gram

    def iter_in_sent(sent):
        for word in sent.split():
            yield word

    print('analysing text ...')

    with open(text_file) as f:
        list_utterances = f.readlines()

    ngrams_global = FreqDist()
    # Process the corpus in chunks of `split` utterances so each chunk's
    # FreqDist stays small before pruning.
    for i in range(len(list_utterances) // split + 1):
        ngrams = FreqDist()
        text = list_utterances[i * split:(i + 1) * split]
        for utt in tqdm(text):
            # Each line is expected to hold three comma-separated fields,
            # with the label sequence in the middle.
            _, seq_label, _ = utt.strip().split(',')
            ngram = get_N_gram(iter_in_sent(seq_label), n)
            ngrams += ngram

        # Keep only each chunk's 2*k most frequent n-grams before merging
        # them into the global distribution.
        ngrams_global += dict(ngrams.most_common(2 * k))

    if savefile:
        with open(savefile, 'w') as fw:
            for ngram, num in ngrams_global.most_common(k):
                line = '{}:{}'.format(ngram, num)
                fw.write(line + '\n')

    return ngrams_global
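
A minimal usage sketch, assuming the transcript file follows the three-field comma-separated layout the function parses ('uttid,label_sequence,extra') and that the project-internal get_N_gram is importable; the file names and parameter values below are illustrative, not taken from the repo:

# Hypothetical call: collect the k=10000 most frequent 4-grams from
# train.csv and also dump them to ngrams.txt (both paths are assumptions).
top_ngrams = get_dataset_ngram('train.csv', n=4, k=10000,
                               savefile='ngrams.txt')
print(top_ngrams.most_common(5))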
Example #2
def get_preds_ngram(preds, len_preds, n):
    """
    Simply concatenate all sents into one will bring in noisy n-gram at end of each sent.
    Here we count ngrams for each sent and sum them up.
    """
    from utils.dataProcess import get_N_gram

    def iter_preds(preds, len_preds):
        # `preds` is a batch of padded token-id tensors and `len_preds`
        # holds each utterance's true length, so padding tokens are skipped.
        for length, utt in zip(len_preds, preds):
            for token in utt[:length]:
                yield token.numpy()

    ngrams = get_N_gram(iter_preds(preds, len_preds), n)

    return ngrams
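
A minimal sketch of calling get_preds_ngram, assuming eager-mode TensorFlow tensors (the .numpy() call implies eager execution) and that get_N_gram returns an nltk FreqDist as in Example #1; the toy batch below is invented for illustration:

import tensorflow as tf

# Hypothetical batch: two padded utterances of token ids plus their
# true lengths, so the trailing zeros are treated as padding.
preds = tf.constant([[3, 7, 7, 0], [5, 3, 0, 0]])
len_preds = tf.constant([3, 2])

bigrams = get_preds_ngram(preds, len_preds, n=2)
print(bigrams.most_common(3))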