Exemple #1
0
 def test_comment_to_tokens(self):
     """Check comment_to_tokens against two hand-written expectations.

     Per the expected lists below, the numbers come back as '<n>', the
     '<br>' markup and the '.' are absent, '?' is kept as its own token,
     and each result ends with '<eol>', '<eoc>'.
     """
     cases = (
         (
             'имея внешку 3 из 10. что тебе мешает завести хотя бы одну?',
             [
                 'имея', 'внешку', '<n>', 'из', '<n>', 'что', 'тебе',
                 'мешает', 'завести', 'хотя', 'бы', 'одну', '?', '<eol>',
                 '<eoc>',
             ],
         ),
         ('qwe<br> adasd', ['qwe', 'adasd', '<eol>', '<eoc>']),
     )
     for text, expected in cases:
         assert comment_to_tokens(text) == expected
Exemple #2
0
def main():
    """Run the generator over comments read from the test data directory.

    Files in ``OPTS.test_data_dir`` whose names begin with digits followed
    by ``.txt`` are read in sorted order, split into comments on blank
    lines and tokenized with ``comment_to_tokens``. Token sequences of 10
    to 80 tokens (inclusive) are collected and handed to
    ``Generator.generate`` together with a callback that prints the seed
    and result tokens.
    """
    # Raw string: in a plain literal the backslash-d / backslash-dot pairs
    # are invalid escape sequences and trigger a warning on recent Pythons.
    name_pattern = re.compile(r'\d+\.txt')

    sequences = []
    for name in sorted(os.listdir(OPTS.test_data_dir)):
        # match() anchors only at the start of the name (unchanged).
        if not name_pattern.match(name):
            continue

        # Read bytes and decode explicitly so that no newline translation
        # happens before the split on blank lines.
        with open(os.path.join(OPTS.test_data_dir, name), 'rb') as f:
            content = f.read().decode('utf-8')
        for comment in content.split('\n\n'):
            tokens = comment_to_tokens(comment)
            if 10 <= len(tokens) <= 80:
                sequences.append(tokens)

    def callback(i, res_tokens, seed_tokens):
        """Print the seed tokens, then the result tokens, space-joined."""
        print('')
        print(' '.join(seed_tokens))
        print('>>>>>>', ' '.join(res_tokens))

    print('sequences:', len(sequences))

    # An unset/empty option means no tokens are forbidden.
    if OPTS.forbidden_tokens:
        forbidden_tokens = OPTS.forbidden_tokens.split(',')
    else:
        forbidden_tokens = ()

    generator = Generator(OPTS.weights_file, OPTS.id2token_file,
                          OPTS.embedding_size, OPTS.hidden_size)
    generator.generate(sequences,
                       forbidden_tokens=forbidden_tokens,
                       max_res_len=OPTS.max_res_len,
                       callback=callback)
Exemple #3
0
def produce_post(thread_id, reply_to):
    """Generate a reply to ``reply_to`` and put it on ``posting_queue``.

    The comment of ``reply_to`` is tokenized and used as the single seed
    for ``generator.generate``; the first item of its return value is
    converted to text with ``tokens_to_string``. When ``OPTS.pics_dir`` is
    truthy a picture is chosen with ``select_random_pic``, otherwise the
    picture slot is ``None``. The queued item is the tuple
    ``(comment, pic_file, thread_id, reply_to)``.
    """
    seeds = (comment_to_tokens(reply_to.comment), )
    results = generator.generate(seeds,
                                 forbidden_tokens=('<unk>', ),
                                 min_res_len=3,
                                 max_res_len=OPTS.max_res_len)
    comment = tokens_to_string(results[0])

    pic_file = select_random_pic(OPTS.pics_dir) if OPTS.pics_dir else None

    posting_queue.put((comment, pic_file, thread_id, reply_to))
Exemple #4
0
def select_thread_posts(board, thread_id, max_posts, min_post_len,
                        max_post_len):
    """Pick posts from a thread whose token count is within bounds.

    The thread's posts are fetched through ``api.get_thread_posts`` and
    shuffled in place. Each visited post has its ``comment`` replaced by
    ``filter_data(comment)``; it is selected when the tokenized comment has
    between ``min_post_len`` and ``max_post_len`` tokens (inclusive).
    Iteration stops as soon as exactly ``max_posts`` posts are selected.

    Returns the list of selected post objects.
    """
    candidates = api.get_thread_posts(board, thread_id)
    random.shuffle(candidates)

    chosen = []
    for post in candidates:
        # The filtered text is stored back on the post object.
        post.comment = filter_data(post.comment)

        token_count = len(comment_to_tokens(post.comment))
        if not min_post_len <= token_count <= max_post_len:
            continue

        chosen.append(post)
        if len(chosen) == max_posts:
            break

    return chosen
Exemple #5
0
def select_threads(board, max_threads, min_post_len, max_post_len):
    """Pick threads that have enough posts and a suitably sized first post.

    For every thread id from ``api.get_threads`` the posts are fetched;
    threads with no posts are skipped. All fetched posts get their
    ``comment`` replaced by ``filter_data(comment)``. A thread is selected
    when it has at least 3 posts and its first post tokenizes to between
    ``min_post_len`` and ``max_post_len`` tokens (inclusive). Iteration
    stops as soon as exactly ``max_threads`` threads are selected.

    Returns a list of ``(thread_id, posts)`` tuples.
    """
    chosen = []

    for thread_id in api.get_threads(board):
        posts = api.get_thread_posts(board, thread_id)
        if not posts:
            continue

        # Filtering is applied to every post, even if the thread is later
        # rejected.
        for post in posts:
            post.comment = filter_data(post.comment)

        first_len = len(comment_to_tokens(posts[0].comment))
        if len(posts) < 3:
            continue
        if not min_post_len <= first_len <= max_post_len:
            continue

        chosen.append((thread_id, posts))
        if len(chosen) == max_threads:
            break

    return chosen