Example No. 1
def evaluate_transformer(model, metrics, test_loader, vocab_desc, vocab_api,
                         f_eval):
    ivocab_api = {v: k for k, v in vocab_api.items()}
    ivocab_desc = {v: k for k, v in vocab_desc.items()}
    device = next(model.parameters()).device
    recall_bleus, prec_bleus = [], []
    local_t = 0
    for old_descs, desc_lens, apiseqs, api_lens in tqdm(test_loader):
        # print("shape desc",old_descs.shape," shape api",apiseqs.shape)
        # shift left by one position so the leading SOS token is dropped
        descs = torch.zeros_like(old_descs)
        descs[:, :-1] = old_descs[:, 1:]
        descs[descs == 2] = 0  # token id 2 (presumably EOS) is replaced with PAD
        # print("test No.",local_t)
        if local_t > 1000:
            break

        desc_str = indexes2sent(descs[0].numpy(), vocab_desc)
        # print("test evaluate: desc_str",desc_str)
        src_data, desc_lens = [
            tensor.to(device) for tensor in [descs, desc_lens]
        ]

        # print("source data", src_data)
        e_mask = (src_data != 0).unsqueeze(1).to(device)
        src_data = model.src_embedding(src_data)
        src_data = model.positional_encoder(src_data)
        e_output = model.encoder(src_data, e_mask)

        pred_trg = beam_search(model, device, e_output, e_mask)
        # print("predicted target", pred_trg)
        # print("actual target", apiseqs)
        pred_trg = np.array(pred_trg)
        # print("pred trg:",pred_trg)
        pred_sents, _ = indexes2sent(pred_trg, vocab_api)
        # print("pred sent:",pred_sents)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        ref_str, _ = indexes2sent(apiseqs[0].numpy(), vocab_api)
        ref_tokens = ref_str.split(' ')
        # print("pred token:",pred_tokens)
        # print("actu token:",ref_tokens)
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)
        local_t += 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        f_eval.write(f"Query: {desc_str} \n")
        f_eval.write("Target >> %s\n" %
                     (ref_str.replace(" ' ", "'")))  # print the true outputs

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)

    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f" % (
        recall_bleu, prec_bleu, f1)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")
Example No. 2
def evaluate(model, metrics, test_loader, vocab_desc, vocab_api, repeat,
             decode_mode, f_eval):
    ivocab_api = {v: k for k, v in vocab_api.items()}
    ivocab_desc = {v: k for k, v in vocab_desc.items()}
    device = next(model.parameters()).device

    recall_bleus, prec_bleus = [], []
    local_t = 0
    for descs, desc_lens, apiseqs, api_lens in tqdm(test_loader):

        if local_t > 100:
            break

        desc_str = indexes2sent(descs[0].numpy(), vocab_desc)
        # print("test evaluate: desc_str",desc_str)
        descs, desc_lens = [tensor.to(device) for tensor in [descs, desc_lens]]
        # print("test evaluate: descs",descs)
        with torch.no_grad():
            sample_words, sample_lens = model.sample(descs, desc_lens, repeat,
                                                     decode_mode)
        # nparray: [repeat x seq_len]
        # print("pred trg",sample_words)
        pred_sents, _ = indexes2sent(sample_words, vocab_api)
        # print("pred sent:",pred_sents)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        ref_str, _ = indexes2sent(apiseqs[0].numpy(), vocab_api)
        ref_tokens = ref_str.split(' ')
        # print("pred token:",pred_tokens)
        # print("actu token:",ref_tokens)
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)

        local_t += 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        f_eval.write(f"Query: {desc_str} \n")
        f_eval.write("Target >> %s\n" %
                     (ref_str.replace(" ' ", "'")))  # print the true outputs
        for r_id, pred_sent in enumerate(pred_sents):
            f_eval.write("Sample %d >> %s\n" %
                         (r_id, pred_sent.replace(" ' ", "'")))
        f_eval.write("\n")

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)

    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f" % (
        recall_bleu, prec_bleu, f1)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")

    return recall_bleu, prec_bleu
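
The per-example scores are reduced to a recall BLEU (mean of the per-example maxima) and a precision BLEU (mean of the per-example averages), then combined into an F1 with a tiny constant in the denominator so an all-zero run does not divide by zero. A standalone sketch of that aggregation with made-up scores:

import numpy as np

# hypothetical per-example scores collected in a loop like the one above
recall_bleus = [0.31, 0.42, 0.27]  # best BLEU over the sampled outputs
prec_bleus = [0.22, 0.35, 0.19]    # mean BLEU over the sampled outputs

recall_bleu = float(np.mean(recall_bleus))
prec_bleu = float(np.mean(prec_bleus))
# small epsilon (the examples above use 10e-12) keeps the division safe
f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 1e-12)
print(recall_bleu, prec_bleu, f1)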
Example No. 3
def evaluate(model, metrics, test_loader, vocab_desc, vocab_api, f_eval,
             repeat):
    ivocab_api = {v: k for k, v in vocab_api.items()}
    ivocab_desc = {v: k for k, v in vocab_desc.items()}

    recall_bleus, prec_bleus = [], []
    local_t = 0
    for descs, apiseqs, desc_lens, api_lens in tqdm(test_loader):

        if local_t > 2000:
            break

        desc_str = indexes2sent(descs[0].numpy(), vocab_desc)

        descs, desc_lens = gVar(descs), gVar(desc_lens)
        sample_words, sample_lens = model.sample(descs, desc_lens, repeat)
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab_api)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        ref_str, _ = indexes2sent(apiseqs[0].numpy(), vocab_api,
                                  vocab_api["<s>"])
        ref_tokens = ref_str.split(' ')

        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)

        local_t += 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        f_eval.write("Query: {} \n".format(desc_str))
        f_eval.write("Target >> %s\n" %
                     (ref_str.replace(" ' ", "'")))  # print the true outputs
        for r_id, pred_sent in enumerate(pred_sents):
            f_eval.write("Sample %d >> %s\n" %
                         (r_id, pred_sent.replace(" ' ", "'")))
        f_eval.write("\n")

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)

    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f" % (
        recall_bleu, prec_bleu, f1)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")

    return recall_bleu, prec_bleu
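
Each example leans on indexes2sent to turn an index array back into a whitespace-joined sentence. A rough, single-sequence sketch of that decoding, assuming PAD id 0 and a "</s>" end token (both assumptions; the real helper also handles batches of sequences and other stop tokens):

import numpy as np


def indexes2sent_sketch(indexes, vocab, eos_tok="</s>", pad_id=0):
    # invert the word->id vocabulary and decode until the end token
    ivocab = {v: k for k, v in vocab.items()}
    words = []
    for idx in np.asarray(indexes).ravel():
        idx = int(idx)
        if ivocab.get(idx) == eos_tok:
            break
        if idx == pad_id:
            continue
        words.append(ivocab.get(idx, "<unk>"))
    sentence = " ".join(words)
    return sentence, len(words)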
Example No. 4
def evaluate(model, test_loader, vocab, repeat, f_eval):
    # TODO:
    z_vec = []
    local_t = 0
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        local_t += 1
        if local_t != 1033:
            continue
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # drop the SOS token from each context utterance and shorten the lengths accordingly
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            context_str = indexes2sent(context[0, t_id], vocab, vocab["</s>"],
                                       PAD_token)
            f_eval.write("Context %d-%d: %s\n" %
                         (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"],
                                  vocab["<s>"])
        ref_tokens = ref_str.split(' ')
        f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))

        context, context_lens, utt_lens, floors = gVar(context), gVar(
            context_lens), gVar(utt_lens), gData(floors)
        prior_z, c_repeated = model.sample_fix_z(context, context_lens,
                                                 utt_lens, floors, repeat,
                                                 vocab["<s>"], vocab["</s>"])

        # nparray: [repeat x hidden_z]
        z_vec.append(prior_z)

    return z_vec
Example No. 5
def evaluate(model, metrics, test_loader, vocab, ivocab, f_eval, repeat):

    recall_bleus, prec_bleus = [], []
    bows_extrema, bows_avg, bows_greedy = [], [], []
    intra_dist1s, intra_dist2s, inter_dist1s, inter_dist2s = [], [], [], []
    avg_lens = []
    local_t = 0
    while True:
        batch = test_loader.next_batch()
        if batch is None:
            break
        local_t += 1
        context, context_lens, utt_lens, floors, _, _, _, response, res_lens, _ = batch
        # drop the SOS token from each context utterance and shorten the lengths accordingly
        context, utt_lens = context[:, :, 1:], utt_lens - 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        start = np.maximum(0, context_lens[0] - 5)
        for t_id in range(start, context.shape[1], 1):
            context_str = indexes2sent(context[0, t_id], vocab, vocab["</s>"],
                                       PAD_token)
            f_eval.write("Context %d-%d: %s\n" %
                         (t_id, floors[0, t_id], context_str))
        # print the true outputs
        ref_str, _ = indexes2sent(response[0], vocab, vocab["</s>"],
                                  vocab["<s>"])
        ref_tokens = ref_str.split(' ')
        f_eval.write("Target >> %s\n" % (ref_str.replace(" ' ", "'")))

        context, context_lens, utt_lens, floors = gVar(context), gVar(
            context_lens), gVar(utt_lens), gData(floors)
        sample_words, sample_lens = model.sample(context, context_lens,
                                                 utt_lens, floors, repeat,
                                                 vocab["<s>"], vocab["</s>"])
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab, vocab["</s>"],
                                     PAD_token)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        for r_id, pred_sent in enumerate(pred_sents):
            f_eval.write("Sample %d >> %s\n" %
                         (r_id, pred_sent.replace(" ' ", "'")))
        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)

        bow_extrema, bow_avg, bow_greedy = metrics.sim_bow(
            sample_words, sample_lens, response[:, 1:], res_lens - 2)
        bows_extrema.append(bow_extrema)
        bows_avg.append(bow_avg)
        bows_greedy.append(bow_greedy)

        intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(
            sample_words, sample_lens)
        intra_dist1s.append(intra_dist1)
        intra_dist2s.append(intra_dist2)
        avg_lens.append(np.mean(sample_lens))
        inter_dist1s.append(inter_dist1)
        inter_dist2s.append(inter_dist2)

        f_eval.write("\n")

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)
    bow_extrema = float(np.mean(bows_extrema))
    bow_avg = float(np.mean(bows_avg))
    bow_greedy = float(np.mean(bows_greedy))
    intra_dist1 = float(np.mean(intra_dist1s))
    intra_dist2 = float(np.mean(intra_dist2s))
    avg_len = float(np.mean(avg_lens))
    inter_dist1 = float(np.mean(inter_dist1s))
    inter_dist2 = float(np.mean(inter_dist2s))
    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f, bow_extrema %f, bow_avg %f, bow_greedy %f,\
    intra_dist1 %f, intra_dist2 %f, avg_len %f, inter_dist1 %f, inter_dist2 %f (only 1 ref, not final results)" \
    % (recall_bleu, prec_bleu, f1, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")

    return recall_bleu, prec_bleu, bow_extrema, bow_avg, bow_greedy, intra_dist1, intra_dist2, avg_len, inter_dist1, inter_dist2
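
metrics.div_distinct is reported here as intra- and inter-sample distinct-1/distinct-2 diversity. A simplified sketch of the usual formulation (not necessarily the repository's exact one): intra-distinct is the unique-n-gram ratio within each sampled sequence, averaged over samples; inter-distinct pools the n-grams of all samples before taking the ratio.

import numpy as np


def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def div_distinct_sketch(samples, n):
    # samples: list of token lists produced for one input
    intra, pooled = [], []
    for tokens in samples:
        grams = ngrams(tokens, n)
        intra.append(len(set(grams)) / max(len(grams), 1))
        pooled.extend(grams)
    inter = len(set(pooled)) / max(len(pooled), 1)
    return float(np.mean(intra)), float(inter)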
Example No. 6
def train(embedder,
          encoder,
          hidvar,
          decoder,
          data_loader,
          vocab,
          n_iters,
          model_dir,
          p_teach_force=0.5,
          save_every=5000,
          sample_every=100,
          print_every=10,
          plot_every=100,
          learning_rate=0.00005):
    start = time.time()
    print_time_start = start
    plot_losses = []
    print_loss_total, print_loss_kl, print_loss_decoder = 0., 0., 0.  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    embedder_optimizer = optim.Adam(embedder.parameters(), lr=learning_rate)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    hidvar_optimizer = optim.Adam(hidvar.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    # average the loss over the batch; ignore_index=EOS_token was left disabled
    criterion = nn.NLLLoss(weight=None, reduction='mean')

    data_iter = iter(data_loader)

    for it in range(1, n_iters + 1):
        q_batch, a_batch, q_lens, a_lens = next(data_iter)

        q_batch, a_batch, q_lens, a_lens = sortbatch(
            q_batch, a_batch, q_lens,
            a_lens)  # !!! important for pack sequence
        # sort sequences according to their lengthes in descending order

        kl_anneal_weight = (math.tanh((it - 3500) / 1000) + 1) / 2

        total_loss, kl_loss, decoder_loss = _train_step(
            q_batch, a_batch, q_lens, a_lens, embedder, encoder, hidvar,
            decoder, embedder_optimizer, encoder_optimizer, hidvar_optimizer,
            decoder_optimizer, criterion, kl_anneal_weight, p_teach_force)

        print_loss_total += total_loss
        print_loss_kl += kl_loss
        print_loss_decoder += decoder_loss
        plot_loss_total += total_loss
        if it % save_every == 0:
            if not os.path.exists('%slatentvar_%s/' % (model_dir, str(it))):
                os.makedirs('%slatentvar_%s/' % (model_dir, str(it)))
            torch.save(f='%slatentvar_%s/embedder.pckl' % (model_dir, str(it)),
                       obj=embedder)
            torch.save(f='%slatentvar_%s/encoder.pckl' % (model_dir, str(it)),
                       obj=encoder)
            torch.save(f='%slatentvar_%s/hidvar.pckl' % (model_dir, str(it)),
                       obj=hidvar)
            torch.save(f='%slatentvar_%s/decoder.pckl' % (model_dir, str(it)),
                       obj=decoder)
        if it % sample_every == 0:
            samp_idx = np.random.choice(len(q_batch), 4)  #pick 4 samples
            for i in samp_idx:
                question, target = q_batch[i].view(1,
                                                   -1), a_batch[i].view(1, -1)
                sampled_sentence = sample(embedder, encoder, hidvar, decoder,
                                          question, vocab)
                ivocab = {v: k for k, v in vocab.items()}
                print('question: %s' % (indexes2sent(
                    question.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('target: %s' % (indexes2sent(
                    target.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('predicted: %s' % (sampled_sentence))
        if it % print_every == 0:
            print_loss_total = print_loss_total / print_every
            print_loss_kl = print_loss_kl / print_every
            print_loss_decoder = print_loss_decoder / print_every
            print_time = time.time() - print_time_start
            print_time_start = time.time()
            print(
                'iter %d/%d  step_time:%ds  total_time:%s tol_loss: %.4f kl_loss: %.4f dec_loss: %.4f'
                % (it, n_iters, print_time, timeSince(start, it / n_iters),
                   print_loss_total, print_loss_kl, print_loss_decoder))
            print_loss_total, print_loss_kl, print_loss_decoder = 0, 0, 0
        if it % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
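
The KL annealing weight used above, (tanh((it - 3500) / 1000) + 1) / 2, ramps smoothly from roughly 0 to roughly 1 and crosses 0.5 at iteration 3500. A quick standalone check of the schedule (the iteration values below are chosen only for illustration):

import math

for it in (0, 2500, 3500, 4500, 7000):
    w = (math.tanh((it - 3500) / 1000) + 1) / 2
    print(f"iter {it:5d}: kl_anneal_weight = {w:.3f}")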
Example No. 7
def train(embedder, encoder, topic_picker, first_word_picker, decoder,
          learning_rate, data_loader, n_iters, topic_size, voca_size,
          save_every, sample_every, print_every, plot_every, model_dir, vocab):
    start = time.time()
    print_time_start = start
    plot_losses = []
    print_loss_total, print_loss_topic, print_loss_word, print_loss_decoder = 0., 0., 0., 0.
    plot_loss_total = 0.  # accumulated for the plotted loss window

    embedder_optimizer = optim.Adam(embedder.parameters(), lr=learning_rate)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    topic_picker_optimizer = optim.Adam(topic_picker.parameters(),
                                        lr=learning_rate)
    first_word_picker_optimizer = optim.Adam(first_word_picker.parameters(),
                                             lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    nll_loss = nn.NLLLoss()
    CE_loss = torch.nn.CrossEntropyLoss()

    data_iter = iter(data_loader)

    for it in range(1, n_iters + 1):
        batch_q, batch_ans, batch_topic = next(data_iter)

        #anneal weight?
        topic_loss_weight = 0.2
        word_loss_weight = 0.2

        topic_loss, first_word_loss, decoder_loss, total_loss = train_once(
            embedder, encoder, topic_picker, first_word_picker, decoder,
            batch_q, batch_ans, batch_topic, topic_size, voca_size,
            topic_loss_weight, word_loss_weight, embedder_optimizer,
            encoder_optimizer, topic_picker_optimizer,
            first_word_picker_optimizer, decoder_optimizer, nll_loss, CE_loss)

        print_loss_total += total_loss
        print_loss_decoder += decoder_loss
        print_loss_word += first_word_loss
        print_loss_topic += topic_loss
        plot_loss_total += total_loss

        if it % save_every == 0:
            if not os.path.exists('%sthree_net_%s/' % (model_dir, str(it))):
                os.makedirs('%sthree_net_%s/' % (model_dir, str(it)))
            torch.save(f='%sthree_net_%s/embedder.pckl' % (model_dir, str(it)),
                       obj=embedder)
            torch.save(f='%sthree_net_%s/encoder.pckl' % (model_dir, str(it)),
                       obj=encoder)
            torch.save(f='%sthree_net_%s/topic_picker.pckl' %
                       (model_dir, str(it)),
                       obj=topic_picker)
            torch.save(f='%sthree_net_%s/first_word_picker.pckl' %
                       (model_dir, str(it)),
                       obj=first_word_picker)
            torch.save(f='%sthree_net_%s/decoder.pckl' % (model_dir, str(it)),
                       obj=decoder)
        if it % sample_every == 0:
            samp_idx = np.random.choice(len(batch_q), 4)  #pick 4 samples
            for i in samp_idx:
                question, target = batch_q[i].view(1, -1), batch_ans[i].view(
                    1, -1)
                sampled_sentence = sample(embedder, encoder, topic_picker,
                                          first_word_picker, decoder, question,
                                          vocab)
                ivocab = {v: k for k, v in vocab.items()}
                print('question: %s' % (indexes2sent(
                    question.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('target: %s' % (indexes2sent(
                    target.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('predicted: %s' % (sampled_sentence))
        #print and plot
        if it % print_every == 0:
            print_loss_total = print_loss_total / print_every
            print_loss_word = print_loss_word / print_every
            print_loss_topic = print_loss_topic / print_every
            print_loss_decoder = print_loss_decoder / print_every
            print_time = time.time() - print_time_start
            print_time_start = time.time()
            print(
                'iter %d/%d  step_time:%ds  total_time:%s total_loss: %.4f topic_loss: %.4f first_word_loss: %.4f dec_loss: %.4f'
                % (it, n_iters, print_time, timeSince(
                    start, it / n_iters), print_loss_total, print_loss_topic,
                   print_loss_word, print_loss_decoder))

            # plotting is tracked on the print schedule in this example
            plot_loss_avg = plot_loss_total / print_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0.
            print_loss_total, print_loss_topic, print_loss_word, print_loss_decoder = 0., 0., 0., 0.

    showPlot(plot_losses)
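
Both training examples checkpoint entire module objects with torch.save(obj=...). A minimal sketch of the more portable state_dict style, with a placeholder directory layout and module names (assumptions, not the repository's scheme):

import os

import torch


def save_checkpoint_sketch(modules, model_dir, it):
    # modules: dict mapping a name to an nn.Module, e.g. {"encoder": encoder}
    ckpt_dir = os.path.join(model_dir, f"three_net_{it}")
    os.makedirs(ckpt_dir, exist_ok=True)
    for name, module in modules.items():
        torch.save(module.state_dict(), os.path.join(ckpt_dir, f"{name}.pt"))

At load time the matching modules would be rebuilt and restored with module.load_state_dict(torch.load(path)).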