예제 #1
0
    def bot_func(bot, update, args):
        """Bot command handler: reply to the command arguments with a generated phrase.

        Joins ``args`` into one string, tokenizes it, runs it through the
        phrase model ``net`` and sends the decoded phrase back to
        ``update.message.chat_id``. The sampling decoder is used when
        ``prog_args.sample`` is set, the argmax decoder otherwise. Nothing is
        sent when the decoded reply is empty.

        Uses ``emb_dict``, ``rev_emb_dict``, ``net``, ``end_token`` and
        ``prog_args`` from outside this function (enclosing/module scope).

        :param bot: object whose ``send_message`` delivers the reply.
        :param update: incoming update; the reply goes to its message's chat.
        :param args: iterable of strings following the command.
        """
        text = " ".join(args)
        words = utils.tokenize(text)

        seq_1 = data.encode_words(words, emb_dict)
        input_seq = model.pack_input(seq_1, net.emb)
        enc = net.encode(input_seq)

        # Both decoders take identical arguments; choose one, call it once.
        if prog_args.sample:
            decode = net.decode_chain_sampling
        else:
            decode = net.decode_chain_argmax
        # NOTE(review): input_seq.data[0:1] is presumably the first (begin-token)
        # input embedding used to start decoding -- confirm against model.py.
        _, tokens = decode(enc, input_seq.data[0:1],
                           seq_len=data.MAX_TOKENS,
                           stop_at_token=end_token)

        # Strip the trailing end token; the emptiness check avoids an
        # IndexError if the decoder produced no tokens at all.
        if tokens and tokens[-1] == end_token:
            tokens = tokens[:-1]

        reply = data.decode_words(tokens, rev_emb_dict)
        if reply:
            reply_text = utils.untokenize(reply)
            bot.send_message(chat_id=update.message.chat_id, text=reply_text)
 def bot_func(bot, update, args):
     """Bot command handler: reply to the command arguments with a generated phrase.

     Joins ``args`` into one string, tokenizes it, runs it through the
     phrase model ``net`` and sends the decoded phrase back to
     ``update.message.chat_id``. The sampling decoder is used when
     ``prog_args.sample`` is set, the argmax decoder otherwise. Nothing is
     sent when the decoded reply is empty.

     Uses ``emb_dict``, ``rev_emb_dict``, ``net``, ``end_token`` and
     ``prog_args`` from outside this function (enclosing/module scope).

     :param bot: object whose ``send_message`` delivers the reply.
     :param update: incoming update; the reply goes to its message's chat.
     :param args: iterable of strings following the command.
     """
     text = " ".join(args)
     words = utils.tokenize(text)
     seq_1 = data.encode_words(words, emb_dict)
     input_seq = model.pack_input(seq_1, net.emb)
     enc = net.encode(input_seq)
     # Both decoders take identical arguments; choose one, call it once.
     decode = net.decode_chain_sampling if prog_args.sample else net.decode_chain_argmax
     _, tokens = decode(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                        stop_at_token=end_token)
     # Strip the trailing end token; the emptiness check avoids an
     # IndexError if the decoder produced no tokens at all.
     if tokens and tokens[-1] == end_token:
         tokens = tokens[:-1]
     reply = data.decode_words(tokens, rev_emb_dict)
     if reply:
         reply_text = utils.untokenize(reply)
         bot.send_message(chat_id=update.message.chat_id, text=reply_text)
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    parser.add_argument("-s", "--string", help="String to process, otherwise will loop")
    parser.add_argument("--sample", default=False, action="store_true", help="Enable sampling generation instead of argmax")
    parser.add_argument("--self", type=int, default=1, help="Enable self-loop mode with given amount of phrases.")
    args = parser.parse_args()

    # The embedding dictionary is loaded from the directory containing the model file.
    emb_dict = data.load_emb_dict(os.path.dirname(args.model))
    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
    # The network is never moved to a GPU in this script, so map every stored
    # tensor to CPU; without this a checkpoint saved from CUDA tensors fails
    # to load on a CPU-only machine.
    net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))

    # Inverse vocabulary (index -> word) for decoding the network's output.
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    while True:
        if args.string:
            input_string = args.string
        else:
            try:
                input_string = input(">>> ")
            except EOFError:
                # Ctrl-D / closed stdin ends the session cleanly instead of a traceback.
                break
        if not input_string:
            break

        words = utils.tokenize(input_string)
        # Self-loop: each output phrase becomes the next input, args.self times.
        for _ in range(args.self):
            words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=args.sample)
            print(utils.untokenize(words))

        # A string supplied on the command line is processed exactly once.
        if args.string:
            break
예제 #4
0
    args = parser.parse_args()

    # The embedding dictionary is loaded from the directory containing the model file.
    emb_dict = data.load_emb_dict(os.path.dirname(args.model))
    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                            dict_size=len(emb_dict),
                            hid_size=model.HIDDEN_STATE_SIZE)
    # The network is never moved to a GPU in this script, so map every stored
    # tensor to CPU; without this a checkpoint saved from CUDA tensors fails
    # to load on a CPU-only machine.
    net.load_state_dict(torch.load(args.model,
                                   map_location=lambda storage, loc: storage))

    # Inverse vocabulary (index -> word) for decoding the network's output.
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    while True:
        if args.string:
            input_string = args.string
        else:
            try:
                input_string = input(">>> ")
            except EOFError:
                # Ctrl-D / closed stdin ends the session cleanly instead of a traceback.
                break
        if not input_string:
            break

        words = utils.tokenize(input_string)
        # Self-loop: each output phrase becomes the next input, args.self times.
        for _ in range(args.self):
            words = words_to_words(words,
                                   emb_dict,
                                   rev_emb_dict,
                                   net,
                                   use_sampling=args.sample)
            print(utils.untokenize(words))

        # A string supplied on the command line is processed exactly once.
        if args.string:
            break