Example #1
    def bot_func(bot, update, args):
        # emb_dict, rev_emb_dict, net, end_token and prog_args come from
        # the enclosing scope where the bot is set up.
        text = " ".join(args)
        words = utils.tokenize(text)

        seq_1 = data.encode_words(words, emb_dict)
        input_seq = model.pack_input(seq_1, net.emb)

        enc = net.encode(input_seq)

        if prog_args.sample:
            _, tokens = net.decode_chain_sampling(enc,
                                                  input_seq.data[0:1],
                                                  seq_len=data.MAX_TOKENS,
                                                  stop_at_token=end_token)
        else:
            _, tokens = net.decode_chain_argmax(enc,
                                                input_seq.data[0:1],
                                                seq_len=data.MAX_TOKENS,
                                                stop_at_token=end_token)

        if tokens and tokens[-1] == end_token:
            tokens = tokens[:-1]

        reply = data.decode_words(tokens, rev_emb_dict)

        if reply:
            reply_text = utils.untokenize(reply)
            bot.send_message(chat_id=update.message.chat_id, text=reply_text)
Example #2
def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    # Encode the input words to IDs, run the seq2seq model, and decode the
    # resulting IDs back into words.
    tokens = data.encode_words(words, emb_dict)
    input_seq = model.pack_input(tokens, net.emb)
    enc = net.encode(input_seq)
    end_token = emb_dict[data.END_TOKEN]
    if use_sampling:
        _, out_tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                                  stop_at_token=end_token)
    else:
        _, out_tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                                stop_at_token=end_token)
    if out_tokens and out_tokens[-1] == end_token:
        out_tokens = out_tokens[:-1]
    out_words = data.decode_words(out_tokens, rev_emb_dict)
    return out_words
Example #3
def bot_func(bot, update, args):
    text = " ".join(args)
    words = utils.tokenize(text)
    seq_1 = data.encode_words(words, emb_dict)
    input_seq = model.pack_input(seq_1, net.emb)
    enc = net.encode(input_seq)
    if prog_args.sample:
        _, tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                              stop_at_token=end_token)
    else:
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                            stop_at_token=end_token)
    if tokens and tokens[-1] == end_token:
        tokens = tokens[:-1]
    reply = data.decode_words(tokens, rev_emb_dict)
    if reply:
        reply_text = utils.untokenize(reply)
        bot.send_message(chat_id=update.message.chat_id, text=reply_text)
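The two bot_func examples above are Telegram command handlers, but neither shows how the handler is registered. A minimal wiring sketch, assuming the pre-v12 python-telegram-bot API (Updater/Dispatcher with pass_args); the token and the command name are placeholders, not taken from the examples:

# Hypothetical registration for bot_func (python-telegram-bot < 12 assumed).
from telegram.ext import Updater, CommandHandler

updater = Updater(token="TELEGRAM-BOT-TOKEN")  # placeholder token
# "/bot hello there" calls bot_func(bot, update, args=["hello", "there"])
updater.dispatcher.add_handler(CommandHandler("bot", bot_func, pass_args=True))
updater.start_polling()
updater.idle()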
Example #4
def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    tokens = data.encode_words(words, emb_dict)
    input_seq = model.pack_input(tokens, net.emb)
    enc = net.encode(input_seq)
    end_token = emb_dict[data.END_TOKEN]
    if use_sampling:
        _, out_tokens = net.decode_chain_sampling(enc,
                                                  input_seq.data[0:1],
                                                  seq_len=data.MAX_TOKENS,
                                                  stop_at_token=end_token)
    else:
        _, out_tokens = net.decode_chain_argmax(enc,
                                                input_seq.data[0:1],
                                                seq_len=data.MAX_TOKENS,
                                                stop_at_token=end_token)
    if out_tokens and out_tokens[-1] == end_token:
        out_tokens = out_tokens[:-1]
    out_words = data.decode_words(out_tokens, rev_emb_dict)
    return out_words
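words_to_words above is the reusable core of bot_func: it maps a tokenized phrase to the model's reply words. A hypothetical call site, assuming net, emb_dict and rev_emb_dict are already loaded as in the surrounding examples:

# Hypothetical usage of words_to_words; model and dictionaries assumed loaded.
words = utils.tokenize("how are you ?")
reply = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False)
print(utils.untokenize(reply))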
Example #5
            for p in dial:
                print(" ".join(p))
            print()

    if args.show_train or args.show_dict_freq:
        phrase_pairs, emb_dict = data.load_data(genre_filter=args.genre)

    if args.show_train:
        rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
        train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
        train_data = data.group_train_data(train_data)
        unk_token = emb_dict[data.UNKNOWN_TOKEN]

        print("Training pairs (%d total)" % len(train_data))
        train_data.sort(key=lambda p: len(p[1]), reverse=True)
        for idx, (p1, p2_group) in enumerate(train_data):
            w1 = data.decode_words(p1, rev_emb_dict)
            w2_group = [data.decode_words(p2, rev_emb_dict) for p2 in p2_group]
            print("%d:" % idx, " ".join(w1))
            for w2 in w2_group:
                print("%s:" % (" " * len(str(idx))), " ".join(w2))

    if args.show_dict_freq:
        words_stat = collections.Counter()
        for p1, p2 in phrase_pairs:
            words_stat.update(p1)
        print("Frequency stats for %d tokens in the dict" % len(emb_dict))
        for token, count in words_stat.most_common():
            print("%s: %d" % (token, count))
    pass
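Example #5 relies on data.group_train_data, whose output is unpacked above as (p1, p2_group) pairs, to collapse flat (phrase, reply) pairs into one entry per distinct phrase. A hypothetical sketch of that helper, inferred from its call site rather than taken from the library:

# Hypothetical sketch of data.group_train_data, inferred from how its result
# is unpacked above: group all replies that share the same first phrase.
import collections

def group_train_data(training_data):
    groups = collections.defaultdict(list)
    for p1, p2 in training_data:
        groups[tuple(p1)].append(p2)   # lists are unhashable, so key by tuple
    return [(list(phrase), replies) for phrase, replies in groups.items()]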
Example #6
                    item_enc = net.get_encoded_item(enc, idx)
                    r_argmax, actions = net.decode_chain_argmax(
                        item_enc,
                        beg_embedding,
                        data.MAX_TOKENS,
                        stop_at_token=end_token)
                    argmax_bleu = utils.calc_bleu_many(actions, ref_indices)
                    bleus_argmax.append(argmax_bleu)

                    if not args.disable_skip:
                        if argmax_bleu > 0.99:
                            skipped_samples += 1
                            continue

                    if not dial_shown:
                        w = data.decode_words(inp_idx, rev_emb_dict)
                        log.info("Input: %s", utils.untokenize(w))
                        ref_words = [
                            utils.untokenize(
                                data.decode_words(ref, rev_emb_dict))
                            for ref in ref_indices
                        ]
                        ref = " ~~|~~ ".join(ref_words)
                        log.info("Refer: %s", ref)
                        w = data.decode_words(actions, rev_emb_dict)
                        log.info("Argmax: %s, bleu=%.4f", utils.untokenize(w),
                                 argmax_bleu)

                    for _ in range(args.samples):
                        r_sample, actions = \
                            net.decode_chain_sampling(
Example #7
                    # Get predicted tokens.
                    seq = torch.max(r.data, dim=1)[1]
                    seq = seq.cpu().numpy()
                # Otherwise, train with argmax (greedy) decoding.
                else:
                    r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1],
                                                     len(ref_indices))
                    bleu_temp = utils.calc_bleu(seq, ref_indices)
                    bleu_sum += bleu_temp
                net_results.append(r)
                net_targets.extend(ref_indices)
                bleu_count += 1

                if not dial_shown:
                    # data.decode_words turns token IDs back into words.
                    ref_words = [utils.untokenize(data.decode_words(ref_indices, rev_emb_dict))]
                    log.info("Reference: %s", " ~~|~~ ".join(ref_words))
                    log.info("Predicted: %s, bleu=%.4f", utils.untokenize(data.decode_words(seq, rev_emb_dict)), bleu_temp)
                    dial_shown = True
            # Keep tensors on the chosen device; hard-coded .cuda() calls
            # would fail on CPU-only machines.
            results_v = torch.cat(net_results).to(device)
            targets_v = torch.LongTensor(net_targets).to(device)
            loss_v = F.cross_entropy(results_v, targets_v)
            loss_v.backward()
            optimiser.step()

            losses.append(loss_v.item())
        bleu = bleu_sum / bleu_count
        bleu_test = run_test(test_data, net, end_token, device)
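Example #7 scores every decoded sequence with utils.calc_bleu, which is not shown on this page. A plausible sketch using NLTK's sentence-level BLEU; the smoothing method and the unigram/bigram weighting are assumptions, not confirmed by the example:

# Plausible sketch of utils.calc_bleu: sentence-level BLEU over token-ID lists.
# Smoothing method and the (0.5, 0.5) weights are assumptions.
from nltk.translate import bleu_score

def calc_bleu(cand_seq, ref_seq):
    sf = bleu_score.SmoothingFunction()  # avoids zero scores on short replies
    return bleu_score.sentence_bleu([ref_seq], cand_seq,
                                    smoothing_function=sf.method1,
                                    weights=(0.5, 0.5))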
Example #8
                        qa_info[
                            'pseudo_gold_program_reward'] = pseudo_program_reward

                    # # In this case the BLEU score is already so high that
                    # # there is no need to train on this sample with RL.
                    # if not args.disable_skip and argmax_reward > 0.99:
                    #     skipped_samples += 1
                    #     continue

                    # Within an epoch, the decoded result is logged only for
                    # the first optimized sample; later samples of the epoch
                    # are not displayed.
                    if not dial_shown:
                        # data.decode_words turns token IDs back into words.
                        log.info(
                            "Input: %s",
                            utils.untokenize(
                                data.decode_words(inp_idx, rev_emb_dict)))
                        orig_response = qa_info['orig_response']
                        log.info("orig_response: %s", orig_response)
                        log.info(
                            "Argmax: %s, reward=%.4f",
                            utils.untokenize(
                                data.decode_words(actions, rev_emb_dict)),
                            argmax_reward)

                    sample_logits_list, action_sequence_list = net.beam_decode(
                        hid=item_enc,
                        seq_len=data.MAX_TOKENS,
                        context=context[idx],
                        start_token=beg_token,
                        stop_at_token=end_token,
                        beam_width=args.beam_width,
Example #9
                # Otherwise, train with argmax (greedy) decoding.
                else:
                    r, seq = net.decode_chain_argmax(enc_item,
                                                     out_seq.data[0:1],
                                                     len(ref_indices))
                    bleu_temp = utils.calc_bleu(seq, ref_indices)
                    bleu_sum += bleu_temp
                net_results.append(r)
                net_targets.extend(ref_indices)
                bleu_count += 1

                if not dial_shown:
                    # data.decode_words turns token IDs back into words.
                    ref_words = [
                        utils.untokenize(
                            data.decode_words(ref_indices, rev_emb_dict))
                    ]
                    log.info("Reference: %s", " ~~|~~ ".join(ref_words))
                    log.info(
                        "Predicted: %s, bleu=%.4f",
                        utils.untokenize(data.decode_words(seq, rev_emb_dict)),
                        bleu_temp)
                    dial_shown = True
            # Keep tensors on the chosen device; hard-coded .cuda() calls
            # would fail on CPU-only machines.
            results_v = torch.cat(net_results).to(device)
            targets_v = torch.LongTensor(net_targets).to(device)
            loss_v = F.cross_entropy(results_v, targets_v)
            loss_v.backward()
            optimiser.step()
                    ref_indices = [
                        indices[1:]
                        for indices in output_batch[idx]
                    ]
                    item_enc = net.get_encoded_item(enc, idx)
                    r_argmax, actions = net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
                                                                stop_at_token=end_token)
                    argmax_bleu = utils.calc_bleu_many(actions, ref_indices)
                    bleus_argmax.append(argmax_bleu)

                    if not args.disable_skip and argmax_bleu > 0.99:
                        skipped_samples += 1
                        continue

                    if not dial_shown:
                        log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
                        ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices]
                        log.info("Refer: %s", " ~~|~~ ".join(ref_words))
                        log.info("Argmax: %s, bleu=%.4f", utils.untokenize(data.decode_words(actions, rev_emb_dict)),
                                 argmax_bleu)

                    for _ in range(args.samples):
                        r_sample, actions = net.decode_chain_sampling(item_enc, beg_embedding,
                                                                      data.MAX_TOKENS, stop_at_token=end_token)
                        sample_bleu = utils.calc_bleu_many(actions, ref_indices)

                        if not dial_shown:
                            log.info("Sample: %s, bleu=%.4f", utils.untokenize(data.decode_words(actions, rev_emb_dict)),
                                     sample_bleu)

                        net_policies.append(r_sample)
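The last fragment stops right after collecting the sampled decoding results. In this self-critical setup, the sampled logits typically feed a REINFORCE update in which the argmax BLEU acts as baseline. A minimal sketch of that missing step, assuming net_actions and net_advantages are companion lists filled inside the loop (one chosen token ID, and one sample_bleu - argmax_bleu advantage, per generated token):

# Hypothetical continuation: REINFORCE step with a self-critical baseline.
# net_actions / net_advantages are assumed companion lists to net_policies.
policies_v = torch.cat(net_policies)                  # (total_tokens, vocab)
actions_t = torch.LongTensor(net_actions).to(device)  # sampled token IDs
adv_v = torch.FloatTensor(net_advantages).to(device)  # sample - argmax BLEU

log_prob_v = F.log_softmax(policies_v, dim=1)
# advantage-weighted log-probability of each sampled token
lp_a = adv_v * log_prob_v[range(len(net_actions)), actions_t]
loss_policy_v = -lp_a.mean()  # minimize negative expected reward

optimiser.zero_grad()
loss_policy_v.backward()
optimiser.step()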