예제 #1
0
    def bot_func(bot, update, args):
        """Answer a chat message with a phrase generated by the model.

        The message arguments are joined with spaces, tokenized and encoded
        with ``emb_dict``; the packed sequence is run through ``net.encode``
        and decoded by sampling when ``prog_args.sample`` is truthy, by
        argmax otherwise. A trailing ``end_token`` is dropped from the
        result, and a message is sent only when the decoded reply is
        non-empty.

        ``emb_dict``, ``rev_emb_dict``, ``net``, ``end_token`` and
        ``prog_args`` are taken from the enclosing scope.

        Args:
            bot: object whose ``send_message`` delivers the reply.
            update: incoming update; ``update.message.chat_id`` selects the
                chat to answer.
            args: words of the incoming message.
        """
        query = " ".join(args)
        query_words = utils.tokenize(query)

        query_ids = data.encode_words(query_words, emb_dict)
        packed_query = model.pack_input(query_ids, net.emb)

        hidden = net.encode(packed_query)

        # Both decoders are called with the same arguments, so pick one
        # first and keep a single call site.
        if prog_args.sample:
            decode_chain = net.decode_chain_sampling
        else:
            decode_chain = net.decode_chain_argmax
        _, out_tokens = decode_chain(
            hidden,
            packed_query.data[0:1],
            seq_len=data.MAX_TOKENS,
            stop_at_token=end_token,
        )

        # Drop the end marker so it is not rendered in the reply.
        if out_tokens[-1] == end_token:
            out_tokens = out_tokens[:-1]

        reply_words = data.decode_words(out_tokens, rev_emb_dict)
        if not reply_words:
            return

        answer = utils.untokenize(reply_words)
        bot.send_message(chat_id=update.message.chat_id, text=answer)
def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    """Run a word sequence through the network and return the decoded words.

    Args:
        words: input words, encoded with ``data.encode_words``.
        emb_dict: mapping used to encode ``words``; it must also contain
            ``data.END_TOKEN``.
        rev_emb_dict: mapping handed to ``data.decode_words`` to turn the
            output tokens back into words.
        net: model exposing ``emb``, ``encode``, ``decode_chain_sampling``
            and ``decode_chain_argmax``.
        use_sampling: decode with ``decode_chain_sampling`` when true,
            with ``decode_chain_argmax`` otherwise.

    Returns:
        Whatever ``data.decode_words`` produces for the generated tokens,
        with a trailing end token removed first.
    """
    encoded = data.encode_words(words, emb_dict)
    packed = model.pack_input(encoded, net.emb)
    hidden = net.encode(packed)
    stop_token = emb_dict[data.END_TOKEN]

    # The two decoders take identical arguments; select, then call once.
    if use_sampling:
        decode_chain = net.decode_chain_sampling
    else:
        decode_chain = net.decode_chain_argmax
    _, generated = decode_chain(
        hidden,
        packed.data[0:1],
        seq_len=data.MAX_TOKENS,
        stop_at_token=stop_token,
    )

    # Strip the end marker if decoding stopped on it.
    if generated[-1] == stop_token:
        generated = generated[:-1]
    return data.decode_words(generated, rev_emb_dict)
 def bot_func(bot, update, args):
     """Send a model-generated reply to the chat the message came from.

     ``args`` is joined into one string, tokenized, encoded with
     ``emb_dict`` and passed through ``net.encode``. Decoding uses sampling
     when ``prog_args.sample`` is truthy and argmax otherwise. A trailing
     ``end_token`` is removed, and nothing is sent when the decoded reply
     is empty.

     ``emb_dict``, ``rev_emb_dict``, ``net``, ``end_token`` and
     ``prog_args`` are taken from the enclosing scope.

     Args:
         bot: object whose ``send_message`` delivers the reply.
         update: incoming update; ``update.message.chat_id`` selects the
             chat to answer.
         args: words of the incoming message.
     """
     message = " ".join(args)
     message_words = utils.tokenize(message)
     message_ids = data.encode_words(message_words, emb_dict)
     packed_message = model.pack_input(message_ids, net.emb)
     hidden = net.encode(packed_message)

     # Same arguments for either decoder: choose first, call once.
     if prog_args.sample:
         decode_chain = net.decode_chain_sampling
     else:
         decode_chain = net.decode_chain_argmax
     _, out_tokens = decode_chain(
         hidden,
         packed_message.data[0:1],
         seq_len=data.MAX_TOKENS,
         stop_at_token=end_token,
     )

     # Remove the end marker before converting tokens back to words.
     if out_tokens[-1] == end_token:
         out_tokens = out_tokens[:-1]

     reply_words = data.decode_words(out_tokens, rev_emb_dict)
     if not reply_words:
         return

     answer = utils.untokenize(reply_words)
     bot.send_message(chat_id=update.message.chat_id, text=answer)
예제 #4
0
def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    """Translate input words into the network's generated output words.

    Args:
        words: input words, encoded with ``data.encode_words``.
        emb_dict: mapping used for encoding; ``data.END_TOKEN`` is looked
            up in it to obtain the stop token.
        rev_emb_dict: mapping passed to ``data.decode_words`` for the
            reverse conversion.
        net: model exposing ``emb``, ``encode``, ``decode_chain_sampling``
            and ``decode_chain_argmax``.
        use_sampling: choose ``decode_chain_sampling`` instead of
            ``decode_chain_argmax``.

    Returns:
        The result of ``data.decode_words`` on the generated tokens, after
        a trailing end token (if any) has been cut off.
    """
    source_ids = data.encode_words(words, emb_dict)
    source_seq = model.pack_input(source_ids, net.emb)
    context = net.encode(source_seq)
    end_token = emb_dict[data.END_TOKEN]

    # Resolve the decoding strategy up front so the call is written once.
    if use_sampling:
        decoder = net.decode_chain_sampling
    else:
        decoder = net.decode_chain_argmax
    _, result_tokens = decoder(
        context,
        source_seq.data[0:1],
        seq_len=data.MAX_TOKENS,
        stop_at_token=end_token,
    )

    # The end marker is not part of the answer; drop it when present.
    if result_tokens[-1] == end_token:
        result_tokens = result_tokens[:-1]
    return data.decode_words(result_tokens, rev_emb_dict)
예제 #5
0
 def test_encode_words(self):
     """``data.encode_words`` maps three words to the expected id list."""
     # Expected ids depend on the fixture's ``self.emb_dict``.
     expected_ids = [0, 3, 4, 2, 1]
     encoded = data.encode_words(['a', 'b', 'c'], self.emb_dict)
     self.assertEqual(encoded, expected_ids)
예제 #6
0
 def test_encode_words(self):
     """Encoding ``a b c`` with the fixture dictionary gives known ids."""
     # Expected ids depend on the fixture's ``self.emb_dict``.
     want = [0, 3, 4, 2, 1]
     got = data.encode_words(["a", "b", "c"], self.emb_dict)
     self.assertEqual(got, want)