Code Example #1
File: nlp.py  Project: shark-3/THRED
import codecs
from os.path import join

# fs here is the project's util.fs module (see the split3 import in kv.py, Code Example #3)


def _read_emots():
    # Load the emoticon list stored next to this module, one emoticon per line
    dir, _, _ = fs.split3(__file__)

    emots = set()
    with codecs.getreader('utf-8')(open(join(dir, 'emots.txt'), 'rb')) as e:
        for line in e:
            emots.add(line.strip())

    return emots
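
fs.split3 is the project's path helper used throughout these examples. Its implementation is not shown here, but the call sites suggest it splits a path into (directory, base name without extension, extension). A minimal sketch under that assumption, followed by how _read_emots would typically be used:

import os


def split3(path):
    # Assumed behavior of util.fs.split3: '/data/corpus.txt' -> ('/data', 'corpus', '.txt')
    directory, filename = os.path.split(path)
    name, ext = os.path.splitext(filename)
    return directory, name, ext


# _read_emots() then simply loads the emoticon set stored in emots.txt next to nlp.py:
# emots = _read_emots()
# is_emoticon = token in emots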
Code Example #2
    def _pre_model_creation(self):
        # The topic vocabulary (topic_vocab.in) is expected next to the main vocab file
        self.config['topic_vocab_file'] = os.path.join(fs.split3(self.config.vocab_file)[0], 'topic_vocab.in')
        self._vocab_table, self.__topic_vocab_table = topical_base.initialize_vocabulary(self.config)

        # Create and save the word-embedding matrix for the vocabulary
        WordEmbeddings(self.config.embed_conf).create_and_save(
            self.config.vocab_pkl, self.config.vocab_file,
            self.config.embedding_type, self.config.embedding_size)

        if 'original_vocab_size' not in self.config:
            self.config['original_vocab_size'] = self.config.vocab_size
        self.config.vocab_size = len(self._vocab_table)
        self.config.topic_vocab_size = len(self.__topic_vocab_table)

        if self.config.mode == "interactive" and self.config.lda_model_dir is None:
            raise ValueError("In interactive mode, TA-Seq2Seq requires a pretrained LDA model")
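
This method belongs to THRED's TA-Seq2Seq model class (not shown here) and expects a config object that supports both item and attribute access. The keys it touches are all visible in the snippet above; the sketch below only illustrates their shape, and every value is a placeholder:

ta_seq2seq_config = {
    'vocab_file': 'workspace/vocab.in',       # topic_vocab.in is expected in the same directory
    'vocab_pkl': 'workspace/vocab.pkl',
    'embed_conf': 'conf/word_embeddings.yml',
    'embedding_type': 'glove',
    'embedding_size': 300,
    'mode': 'train',                          # "interactive" mode additionally requires lda_model_dir
    'lda_model_dir': None,
}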
Code Example #3
File: kv.py  Project: shark-3/THRED
def install_redis(
        install_path='',
        download_url='http://download.redis.io/releases/redis-5.0.3.tar.gz',
        port=6384,
        verbose=True):
    import os
    import subprocess
    import tarfile
    import tempfile
    from urllib.error import URLError, HTTPError
    import urllib.request as url_request

    import redis
    from redis.exceptions import ConnectionError
    from util.fs import split3

    # Skip the installation entirely if a Redis server already answers on this port
    proceed_install = True
    r = redis.Redis(host='localhost', port=port)
    try:
        r.ping()
        proceed_install = False
    except ConnectionError:
        pass

    if proceed_install:
        if not install_path:
            tmp_dir = tempfile.mkdtemp(prefix='redis{}'.format(port))
            install_path = os.path.join(tmp_dir, 'redis')

        if not os.path.exists(install_path):
            working_dir, redis_name, _ = split3(install_path)
            redis_tgzfile = os.path.join(working_dir, 'redis.tar.gz')

            if verbose:
                print('Downloading Redis...')

            try:
                with url_request.urlopen(download_url) as resp, \
                        open(redis_tgzfile, 'wb') as out_file:
                    data = resp.read()  # a `bytes` object
                    out_file.write(data)
            except HTTPError as e:
                if verbose:
                    if e.code == 404:
                        print(
                            'The provided URL seems to be broken. Please find a URL for Redis'
                        )
                    else:
                        print('Error code: ', e.code)
                raise ValueError(e)
            except URLError as e:
                if verbose:
                    print('URL error: ', e.reason)
                raise ValueError(e)

            if verbose:
                print('Extracting Redis...')
            with tarfile.open(redis_tgzfile, 'r:gz') as tgz_ref:
                tgz_ref.extractall(working_dir)

            os.remove(redis_tgzfile)

            redis_dir = None
            for f in os.listdir(working_dir):
                if f.lower().startswith('redis'):
                    redis_dir = os.path.join(working_dir, f)
                    break

            if not redis_dir:
                raise ValueError('Could not locate the extracted Redis directory in {}'.format(working_dir))

            os.rename(redis_dir, os.path.join(working_dir, redis_name))

            if verbose:
                print('Installing Redis...')

            # Adjust the stock redis.conf: larger TCP backlog, daemonize, per-port pid/log
            # file names, the requested port, and relaxed persistence settings
            redis_conf_file = os.path.join(install_path, 'redis.conf')
            subprocess.call([
                'sed', '-i', r's/tcp-backlog [0-9]\+$/tcp-backlog 3000/g',
                redis_conf_file
            ])
            subprocess.call([
                'sed', '-i', 's/daemonize no$/daemonize yes/g', redis_conf_file
            ])
            subprocess.call([
                'sed', '-i',
                r's/pidfile .*\.pid$/pidfile redis_{}.pid/g'.format(port),
                redis_conf_file
            ])
            subprocess.call([
                'sed', '-i', 's/port 6379/port {}/g'.format(port),
                redis_conf_file
            ])
            subprocess.call(
                ['sed', '-i', 's/save 900 1/save 15000 1/g', redis_conf_file])
            subprocess.call(
                ['sed', '-i', 's/save 300 10/#save 300 10/g', redis_conf_file])
            subprocess.call([
                'sed', '-i', 's/save 60 10000/#save 60 10000/g',
                redis_conf_file
            ])
            subprocess.call([
                'sed', '-i',
                's/logfile ""/logfile "redis_{}.log"/g'.format(port),
                redis_conf_file
            ])
            subprocess.call(['make'], cwd=install_path)

        if verbose:
            print('Running Redis on port {}...'.format(port))
        subprocess.call(['src/redis-server', 'redis.conf'], cwd=install_path)

    return install_path
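
install_redis is effectively idempotent: if a server already answers ping() on the given port, the download, build, and launch steps are all skipped. A minimal usage sketch (paths are placeholders; redis-py is the only external dependency assumed here):

import redis

# Download, build, and launch a local Redis on the default port used above (6384),
# then talk to it through redis-py.
redis_home = install_redis(verbose=True)

client = redis.Redis(host='localhost', port=6384)
client.set('greeting', 'hello')
print(client.get('greeting'))  # b'hello'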
Code Example #4
File: corpus_toolkit.py  Project: shark-3/THRED
def analyze(dialogue_corpus, analysis_args, steps_per_log=100000):
    # Scan the corpus once, collecting term frequencies and per-turn utterance-length
    # statistics, and optionally dump frequent/rare word lists and the full TF table.
    dir, fname, _ = fs.split3(dialogue_corpus.data_path)

    if analysis_args.n_frequent_words > 0:
        frequent_words_path = os.path.join(
            dir, '{}.top{}'.format(fname, analysis_args.n_frequent_words))
        frequent_words_file = codecs.getwriter("utf-8")(open(
            frequent_words_path, mode="wb"))
    else:
        frequent_words_file = None

    if analysis_args.n_rare_words > 0:
        rare_words_path = os.path.join(
            dir, '{}.bottom{}'.format(fname, analysis_args.n_rare_words))
        rare_words_file = codecs.getwriter("utf-8")(open(rare_words_path,
                                                         mode="wb"))
    else:
        rare_words_file = None

    if analysis_args.save_tf:
        tf_path = os.path.join(dir, '{}.tf'.format(fname))
        tf_file = codecs.getwriter("utf-8")(open(tf_path, mode="wb"))
    else:
        tf_file = None

    tf_dict = defaultdict(int)

    lno, uno = 0, 0
    wno_stat_per_turn = []
    wno_stat = SampledSummaryStat()
    sw = Stopwatch()

    with codecs.getreader("utf-8")(open(dialogue_corpus.data_path,
                                        mode="rb")) as data_file:
        for line in data_file:
            lno += 1

            if lno % steps_per_log == 0:
                logger.info(
                    '{} lines processed - so far vocab {} - time {}'.format(
                        lno, len(tf_dict), sw.elapsed()))

            for t, utter in enumerate(line.split(
                    dialogue_corpus.utterance_sep)):
                uno += 1
                tokens = utter.split()
                for w in tokens:
                    tf_dict[w] += 1

                if t >= len(wno_stat_per_turn):
                    wno_stat_per_turn.append(SampledSummaryStat())
                wno_stat_per_turn[t].accept(len(tokens))
                wno_stat.accept(len(tokens))

    sorted_vocab = sorted(tf_dict, key=tf_dict.get, reverse=True)

    min_freq_vocab_size, min_freq_vol_size = 0, 0
    vol_size = 0
    for i, w in enumerate(sorted_vocab):
        tf = tf_dict[w]

        if tf >= analysis_args.min_freq:
            min_freq_vocab_size += 1
            min_freq_vol_size += tf

        if i < analysis_args.vocab_size:
            vol_size += tf

        if tf_file:
            tf_file.write('{}\t{}\n'.format(w, tf))
        if frequent_words_file and i < analysis_args.n_frequent_words:
            frequent_words_file.write('{}\n'.format(w))
        if rare_words_file and i > len(
                sorted_vocab) - analysis_args.n_rare_words:
            rare_words_file.write('{}\n'.format(w))

    if frequent_words_file:
        frequent_words_file.close()

    if rare_words_file:
        rare_words_file.close()

    if tf_file:
        tf_file.close()

    print('**** {} ****'.format(os.path.abspath(dialogue_corpus.data_path)))
    print('lines {} | utterances {} | vocab {} tf {}'.format(
        lno, uno, len(tf_dict), wno_stat.get_sum()))
    print('min_freq {} -> {}/{} {:.1f}% - {}/{} {:.1f}%'.format(
        analysis_args.min_freq, min_freq_vocab_size, len(tf_dict),
        100.0 * min_freq_vocab_size / len(tf_dict), min_freq_vol_size,
        wno_stat.get_sum(), 100.0 * min_freq_vol_size / wno_stat.get_sum()))
    print('vocab_size {}/{} {:.1f}% - {}/{} {:.1f}%'.format(
        analysis_args.vocab_size, len(tf_dict),
        100.0 * analysis_args.vocab_size / len(tf_dict), vol_size,
        wno_stat.get_sum(), 100.0 * vol_size / wno_stat.get_sum()))

    print('utterances per line: {:.1f}'.format(uno / lno))
    print(
        'utterance_len: avg {:.1f} - stdev {:.1f} - median {} - min {} - max {}'
        .format(wno_stat.get_average(), wno_stat.get_stdev(),
                wno_stat.get_median(), wno_stat.get_min(), wno_stat.get_max()))
    print('utterance_len per turn')
    for t, stat in enumerate(wno_stat_per_turn):
        print(
            '  turn {} - avg {:.1f} - stdev {:.1f} - median {} - min {} - max {}'
            .format(t, stat.get_average(), stat.get_stdev(), stat.get_median(),
                    stat.get_min(), stat.get_max()))
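
As far as analyze is concerned, its two arguments are plain attribute containers; the attribute names below are taken directly from the code above, while SimpleNamespace and all the values are placeholders standing in for whatever classes the project actually uses:

from types import SimpleNamespace

corpus = SimpleNamespace(
    data_path='workspace/dialogues.txt',  # one dialogue per line
    utterance_sep='\t')                   # separator between utterances within a line (placeholder)
analysis_args = SimpleNamespace(
    n_frequent_words=1000,   # write the 1000 most frequent words to <name>.top1000
    n_rare_words=1000,       # write the 1000 rarest words to <name>.bottom1000
    save_tf=True,            # dump the full term-frequency table to <name>.tf
    min_freq=5,              # report coverage of words occurring at least 5 times
    vocab_size=20000)        # report coverage of a 20000-word vocabulary

analyze(corpus, analysis_args)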
Code Example #5
File: get_embed.py  Project: shark-3/THRED
def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e",
        "--embedding_type",
        default="glove",
        help="Embedding type. "
        "The default embeddings are 'glove' and 'word2vec' (see conf/word_embeddings.yml)."
        "However, other types can be added too.")
    parser.add_argument("-d",
                        "--dimensions",
                        default=300,
                        type=int,
                        help="#dimensions of Embedding vectors")
    parser.add_argument("-f",
                        "--embedding_file",
                        help="Path to a compressed embedding file. "
                        "Gensim would be used to load the embeddings.")

    args = parser.parse_args()

    with open("conf/word_embeddings.yml", 'r') as file:
        embed_args = yaml.load(file)

    try:
        validate(args, embed_args)
    except ValueError as e:
        print(e, file=sys.stderr)
        exit(1)

    embed_key = args.embedding_type.lower()
    dim_key = '{}d'.format(args.dimensions)

    if args.embedding_file is not None:
        if embed_key not in embed_args:
            embed_args[embed_key] = {}

        if dim_key not in embed_args[embed_key]:
            embed_args[embed_key][dim_key] = {
                "file": path.abspath(args.embedding_file)
            }
    else:
        if "url" not in embed_args[embed_key][dim_key]:
            print("Seems no need to download any embedding files!!!")
            exit(0)

        containing_dir, _, _ = split3(path.abspath(__file__))
        workspace_dir = path.abspath(
            path.join(containing_dir, path.pardir, "workspace"))
        mkdir_if_not_exists(workspace_dir)
        embed_dir = path.join(workspace_dir, "embeddings")
        mkdir_if_not_exists(embed_dir)

        uncompressed_file = None
        download_url = embed_args[embed_key][dim_key]["url"]
        embed_file = download_url[download_url.rfind("/") + 1:]
        compressed_file = path.join(embed_dir, embed_file)

        if not path.exists(compressed_file):
            download_embeddings(download_url, compressed_file)

        print("Uncompressing the file...")
        uncompress(compressed_file, embed_dir)
        remove(compressed_file)

        for f in listdir(embed_dir):
            if (f.endswith(".txt") or f.endswith(".bin")) and str(
                    args.dimensions) in f:
                uncompressed_file = path.join(embed_dir, f)
                break

        if embed_key == "glove":
            vec_file = replace_ext(uncompressed_file, 'vec')
            if not path.exists(vec_file):
                print("Converting to gensim understandable format...")
                glove2word2vec.glove2word2vec(uncompressed_file, vec_file)

            rm_if_exists(uncompressed_file)
            uncompressed_file = vec_file

        embed_args[embed_key][dim_key]["file"] = uncompressed_file

    with codecs.getwriter("utf-8")(open("conf/word_embeddings.yml",
                                        "wb")) as f:
        yaml.dump(embed_args, f, default_flow_style=True)

    print("Done. Config file successfully updated.")
Code Example #6
File: reddit_utils.py  Project: shark-3/THRED
    def __get_bot_file(self):
        dir, _, _ = fs.split3(__file__)
        return os.path.join(dir, 'reddit_bots.txt')