# Standard-library imports used across these snippets; TinyRedis, fs, Stopwatch,
# SampledSummaryStat, EmbeddingUtil, topical_base, and logger are project-level
# helpers assumed to be importable from the surrounding package.
import codecs
import os
import re
from collections import defaultdict

from tqdm import tqdm


def rank_ngrams(dialogue_corpus, ngrams, redis_port, steps_per_log=100000):
    if not ngrams:
        raise ValueError('ngrams is required')

    # Per-n in-memory counters; each entry pairs the shared Redis client with a
    # counter dict that is flushed to a Redis hash every steps_per_log lines.
    ngrams_cache = {}
    redis = TinyRedis(port=redis_port, max_connections=1000)

    for ngram in ngrams:
        ngrams_cache[ngram] = (redis, defaultdict(int))

    dir, fname, _ = fs.split3(dialogue_corpus.data_path)

    sw = Stopwatch()
    with codecs.getreader("utf-8")(open(dialogue_corpus.data_path,
                                        mode="rb")) as data_file:
        lno = 0
        for line in data_file:
            lno += 1

            if lno % steps_per_log == 0:
                # Periodically flush the per-n counters into Redis hashes keyed
                # by '<n>#<corpus file name>' to bound memory use.
                if ngrams_cache:
                    for ngram in ngrams_cache:
                        logger.info('{}-grams flushing {} keys...'.format(
                            ngram, len(ngrams_cache[ngram][1])))
                        ngrams_cache[ngram][0].pl_hincrby(
                            '{}#{}'.format(ngram, fname),
                            ngrams_cache[ngram][1])
                        ngrams_cache[ngram] = (ngrams_cache[ngram][0],
                                               defaultdict(int))
                    logger.info('ngrams flushed...')

                logger.info('{} lines processed - time {}'.format(
                    lno, sw.elapsed()))

            for utter in line.split(dialogue_corpus.utterance_sep):
                tokens = utter.split()
                for ngram in ngrams_cache:
                    # Slide a window of width `ngram` across the utterance;
                    # there are len(tokens) - ngram + 1 such windows.
                    for i in range(len(tokens) - ngram + 1):
                        ngrams_cache[ngram][1][
                            ' '.join(tokens[i:i + ngram])] += 1

    for ngram in ngrams_cache:
        logger.info('{}-grams flushing {} keys...'.format(
            ngram, len(ngrams_cache[ngram][1])))
        ngrams_cache[ngram][0].pl_hincrby('{}#{}'.format(ngram, fname),
                                          ngrams_cache[ngram][1])
        ngrams_cache[ngram] = (ngrams_cache[ngram][0], defaultdict(int))
    logger.info('ngrams flushed...')

    for ngram in ngrams_cache:
        with codecs.getwriter("utf-8")(open(
                os.path.join(dir, '{}.{}grams'.format(fname, ngram)),
                'wb')) as ngram_file:
            # hscan iterates the (n-gram, count) pairs of the Redis hash; only
            # n-grams with an accumulated frequency above 10 (hard-coded
            # pruning threshold) are written out.
            for ngram_str, freq in ngrams_cache[ngram][0].hscan('{}#{}'.format(
                    ngram, fname)):
                if int(freq) > 10:
                    ngram_file.write('{}\t{}\n'.format(ngram_str, freq))

    # All cache entries share the single TinyRedis client created above.
    redis.close()
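

# Hypothetical usage sketch for rank_ngrams (not part of the original project):
# the corpus stand-in, file name, and Redis port below are illustrative
# assumptions.
def _example_rank_ngrams():
    from collections import namedtuple

    # rank_ngrams only reads .data_path and .utterance_sep from the corpus.
    Corpus = namedtuple('Corpus', ['data_path', 'utterance_sep'])
    corpus = Corpus(data_path='dialogues.txt', utterance_sep='\t')

    # Accumulates 2-gram and 3-gram counts in Redis hashes and writes the
    # frequent ones to '<name>.2grams' / '<name>.3grams' next to the corpus.
    rank_ngrams(corpus, ngrams=[2, 3], redis_port=6379)
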
def count_ngrams(dialogue_corpus, ngrams, redis_port, steps_per_log=5000):
    if not ngrams:
        raise ValueError('ngrams is required')

    _, fname, _ = fs.split3(dialogue_corpus.data_path)

    def get_key(n):
        return '{}__c{}'.format(fname, n)

    ngrams_cache = {}
    n_words = 0
    with TinyRedis(port=redis_port, max_connections=100) as redis:

        for ngram in ngrams:
            redis.delete(get_key(ngram))
            ngrams_cache[ngram] = set()

        sw = Stopwatch()
        with codecs.getreader("utf-8")(open(dialogue_corpus.data_path,
                                            mode="rb")) as data_file:
            lno = 0
            for line in data_file:
                lno += 1

                if lno % steps_per_log == 0:
                    if ngrams_cache:
                        for ngram in ngrams_cache:
                            logger.info('{}-grams flushing {} keys...'.format(
                                ngram, len(ngrams_cache[ngram])))
                            # Add the buffered n-grams to a per-n HyperLogLog
                            # for approximate distinct counting.
                            redis.pfadd(get_key(ngram), *ngrams_cache[ngram])
                            ngrams_cache[ngram] = set()
                        logger.info('ngrams flushed...')

                    logger.info('{} lines processed - time {}'.format(
                        lno, sw.elapsed()))

                for utter in line.split(dialogue_corpus.utterance_sep):
                    tokens = utter.split()
                    n_words += len(tokens)
                    for ngram in ngrams_cache:
                        # Slide a window of width `ngram` across the utterance;
                        # there are len(tokens) - ngram + 1 such windows.
                        for i in range(len(tokens) - ngram + 1):
                            ngrams_cache[ngram].add(
                                ' '.join(tokens[i:i + ngram]))

        for ngram in ngrams_cache:
            logger.info('{}-grams flushing {} keys...'.format(
                ngram, len(ngrams_cache[ngram])))
            redis.pfadd(get_key(ngram), *ngrams_cache[ngram])
        logger.info('ngrams flushed...')

        logger.info('**** {} ****'.format(fname))
        logger.info("#words = {}".format(n_words))
        for ngram in ngrams:
            # PFCOUNT returns the approximate number of distinct n-grams;
            # distinct-n is reported relative to the total token count.
            ngram_cnt = redis.pfcount(get_key(ngram))
            logger.info('# {}-grams = {} | distinct-{} = {:.3f}'.format(
                ngram, ngram_cnt, ngram, ngram_cnt / n_words))
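

# A guess at the small TinyRedis surface rank_ngrams/count_ngrams rely on,
# sketched here over redis-py for illustration; the project's real TinyRedis
# class may differ (pl_hincrby is assumed to mean "pipelined HINCRBY of every
# (field, amount) pair in a dict").
class TinyRedisSketch(object):

    def __init__(self, port, max_connections=100):
        import redis
        self._client = redis.Redis(port=port,
                                   max_connections=max_connections,
                                   decode_responses=True)

    def pl_hincrby(self, key, field_amounts):
        # Flush a dict of counters into a Redis hash in a single round trip.
        pipe = self._client.pipeline(transaction=False)
        for field, amount in field_amounts.items():
            pipe.hincrby(key, field, amount)
        pipe.execute()

    def hscan(self, key):
        # Iterate the (field, value) pairs of a hash.
        return self._client.hscan_iter(key)

    def delete(self, key):
        self._client.delete(key)

    def pfadd(self, key, *values):
        # HyperLogLog insertion used by count_ngrams.
        return self._client.pfadd(key, *values)

    def pfcount(self, key):
        # Approximate distinct-element count of the HyperLogLog.
        return self._client.pfcount(key)

    def close(self):
        self._client.close()

    # count_ngrams uses the client as a context manager.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
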
    # Excerpted from a TA-Seq2Seq model class: prepares the word and topic
    # vocabularies (and the embedding file) before the model graph is built.
    def _pre_model_creation(self):
        # The topic vocabulary is expected next to the main vocab file.
        self.config['topic_vocab_file'] = os.path.join(
            fs.split3(self.config.vocab_file)[0], 'topic_vocab.in')
        self._vocab_table, self.__topic_vocab_table = topical_base.initialize_vocabulary(
            self.config)

        # Build the embedding file (vocab_h5) only if it does not exist yet.
        EmbeddingUtil(self.config.embed_conf).build_if_not_exists(
            self.config.embedding_type, self.config.vocab_h5,
            self.config.vocab_file)

        # Keep the configured vocab size before overwriting it with the sizes
        # of the loaded lookup tables.
        if 'original_vocab_size' not in self.config:
            self.config['original_vocab_size'] = self.config.vocab_size
        self.config.vocab_size = len(self._vocab_table)
        self.config.topic_vocab_size = len(self.__topic_vocab_table)

        if self.config.mode == "interactive" and self.config.lda_model_dir is None:
            raise ValueError(
                "In interactive mode, TA-Seq2Seq requires a pretrained LDA model"
            )
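

# A guess at the contract of the project's fs.split3 helper, inferred from its
# call sites in these snippets (directory, bare file name, extension); the real
# implementation may differ.
def split3_sketch(path):
    """Split '/a/b/corpus.txt' into ('/a/b', 'corpus', '.txt')."""
    directory, base = os.path.split(path)
    name, ext = os.path.splitext(base)
    return directory, name, ext
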
def analyze(dialogue_corpus, analysis_args, steps_per_log=100000):
    dir, fname, _ = fs.split3(dialogue_corpus.data_path)

    # Optional outputs: the top-N most frequent words, the bottom-N rarest
    # words, and the full term-frequency table.
    if analysis_args.n_frequent_words > 0:
        frequent_words_path = os.path.join(
            dir, '{}.top{}'.format(fname, analysis_args.n_frequent_words))
        frequent_words_file = codecs.getwriter("utf-8")(open(
            frequent_words_path, mode="wb"))
    else:
        frequent_words_file = None

    if analysis_args.n_rare_words > 0:
        rare_words_path = os.path.join(
            dir, '{}.bottom{}'.format(fname, analysis_args.n_rare_words))
        rare_words_file = codecs.getwriter("utf-8")(open(rare_words_path,
                                                         mode="wb"))
    else:
        rare_words_file = None

    if analysis_args.save_tf:
        tf_path = os.path.join(dir, '{}.tf'.format(fname))
        tf_file = codecs.getwriter("utf-8")(open(tf_path, mode="wb"))
    else:
        tf_file = None

    tf_dict = defaultdict(int)

    lno, uno = 0, 0
    wno_stat_per_turn = []
    wno_stat = SampledSummaryStat()
    sw = Stopwatch()

    with codecs.getreader("utf-8")(open(dialogue_corpus.data_path,
                                        mode="rb")) as data_file:
        for line in data_file:
            lno += 1

            if lno % steps_per_log == 0:
                logger.info(
                    '{} lines processed - so far vocab {} - time {}'.format(
                        lno, len(tf_dict), sw.elapsed()))

            for t, utter in enumerate(line.split(
                    dialogue_corpus.utterance_sep)):
                uno += 1
                tokens = utter.split()
                for w in tokens:
                    tf_dict[w] += 1

                if t >= len(wno_stat_per_turn):
                    wno_stat_per_turn.append(SampledSummaryStat())
                wno_stat_per_turn[t].accept(len(tokens))
                wno_stat.accept(len(tokens))

    sorted_vocab = sorted(tf_dict, key=tf_dict.get, reverse=True)

    min_freq_vocab_size, min_freq_vol_size = 0, 0
    vol_size = 0
    for i, w in enumerate(sorted_vocab):
        tf = tf_dict[w]

        if tf >= analysis_args.min_freq:
            min_freq_vocab_size += 1
            min_freq_vol_size += tf

        if i < analysis_args.vocab_size:
            vol_size += tf

        if tf_file:
            tf_file.write('{}\t{}\n'.format(w, tf))
        if frequent_words_file and i < analysis_args.n_frequent_words:
            frequent_words_file.write('{}\n'.format(w))
        # The last n_rare_words entries of the frequency-sorted vocabulary.
        if rare_words_file and i >= len(
                sorted_vocab) - analysis_args.n_rare_words:
            rare_words_file.write('{}\n'.format(w))

    if frequent_words_file:
        frequent_words_file.close()

    if rare_words_file:
        rare_words_file.close()

    if tf_file:
        tf_file.close()

    print('**** {} ****'.format(os.path.abspath(dialogue_corpus.data_path)))
    print('lines {} | utterances {} | vocab {} tf {}'.format(
        lno, uno, len(tf_dict), wno_stat.get_sum()))
    print('min_freq {} -> {}/{} {:.1f}% - {}/{} {:.1f}%'.format(
        analysis_args.min_freq, min_freq_vocab_size, len(tf_dict),
        100.0 * min_freq_vocab_size / len(tf_dict), min_freq_vol_size,
        wno_stat.get_sum(), 100.0 * min_freq_vol_size / wno_stat.get_sum()))
    print('vocab_size {}/{} {:.1f}% - {}/{} {:.1f}%'.format(
        analysis_args.vocab_size, len(tf_dict),
        100.0 * analysis_args.vocab_size / len(tf_dict), vol_size,
        wno_stat.get_sum(), 100.0 * vol_size / wno_stat.get_sum()))

    print('utterances per line: {:.1f}'.format(uno / lno))
    print(
        'utterance_len: avg {:.1f} - stdev {:.1f} - median {} - min {} - max {}'
        .format(wno_stat.get_average(), wno_stat.get_stdev(),
                wno_stat.get_median(), wno_stat.get_min(), wno_stat.get_max()))
    print('utterance_len per turn')
    for t, stat in enumerate(wno_stat_per_turn):
        print(
            '  turn {} - avg {:.1f} - stdev {:.1f} - median {} - min {} - max {}'
            .format(t, stat.get_average(), stat.get_stdev(), stat.get_median(),
                    stat.get_min(), stat.get_max()))
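

# Hypothetical usage sketch for analyze (not part of the original project):
# the corpus and argument objects below are illustrative stand-ins that only
# carry the attributes analyze() actually reads.
def _example_analyze():
    from types import SimpleNamespace

    corpus = SimpleNamespace(data_path='dialogues.txt', utterance_sep='\t')
    analysis_args = SimpleNamespace(
        n_frequent_words=1000,  # write the 1000 most frequent words
        n_rare_words=1000,      # write the 1000 rarest words
        save_tf=True,           # write the full term-frequency table
        min_freq=5,             # report coverage of words seen at least 5 times
        vocab_size=30000)       # report coverage of a 30k-word vocabulary
    analyze(corpus, analysis_args)
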
def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--profanity_file', type=str,
                        help='the profanity file')
    parser.add_argument('-f', '--data_file', type=str, required=True,
                        help='the data file')

    params = parser.parse_args()

    if params.profanity_file is None:
        containing_dir, _, _ = fs.split3(os.path.abspath(__file__))
        profanity_file = os.path.join(containing_dir, "profanity_words.txt")
    else:
        profanity_file = params.profanity_file

    one_word_profanities = set()
    multi_word_profanities = set()
    with codecs.getreader('utf-8')(open(profanity_file, 'rb')) as profanity_reader:
        for line in profanity_reader:
            line = line.strip()
            if not line:
                continue

            if ' ' in line:
                multi_word_profanities.add(line)
            else:
                one_word_profanities.add(line)

    print("Profanity words loaded ({} one words/{} multi words)".format(len(one_word_profanities), len(multi_word_profanities)))

    output_file = fs.replace_ext(params.data_file, 'filtered.txt')
    prof1, profN = 0, 0
    with codecs.getreader('utf-8')(open(params.data_file, 'rb')) as data_file, \
            codecs.getwriter('utf-8')(open(output_file, 'wb')) as out_file:
        for line in tqdm(data_file):
            post = line.strip()
            utterances = post.split("\t")
            filtered = False
            for utterance in utterances:

                # One-word profanities must match a whole whitespace token...
                for word in utterance.split():
                    if word in one_word_profanities:
                        filtered = True
                        prof1 += 1
                        break

                # ...while multi-word profanities are matched as substrings of
                # the utterance.
                for profanity in multi_word_profanities:
                    if profanity in utterance:
                        filtered = True
                        profN += 1
                        break

            if not filtered:
                # Strip characters from the Enclosed Alphanumeric Supplement up
                # through Supplemental Symbols and Pictographs (U+1F100-U+1F9FF)
                # and from the historic-script blocks between Linear B and the
                # end of Egyptian Hieroglyphs (U+10000-U+1342E).
                s, e = u"\U0001F100", u"\U0001F9FF"
                s2, e2 = u"\U00010000", u"\U0001342E"
                post = re.sub(r'[{}-{}{}-{}]'.format(s, e, s2, e2), '', post)
                # Drop BOM / word-joiner characters and U+2E18.
                post = re.sub(r"[\uFEFF\u2060\u2E18]", '', post)
                # Re-normalize whitespace within each utterance.
                line = "\t".join(" ".join(w.strip() for w in u.split())
                                 for u in post.split("\t"))
                out_file.write(line + "\n")

    print("Filtered {} (One word {} / Multi word {})".format(prof1 + profN, prof1, profN))