コード例 #1
0
ファイル: cvae_run.py プロジェクト: ArponKundu/UIU_MojiTalk
            # Accumulate this step's KL and bag-of-words losses into the running
            # lists that are averaged at the next logging point below.
            kl_l.append(kl_loss)
            bow_l.append(bow_loss)

            # Every `test_step` steps: log the mean of each accumulated loss list,
            # then reset the lists so the next report covers a fresh window.
            if global_step % test_step == 0:
                # NOTE(review): time_now is assigned but not read anywhere in the
                # visible lines -- confirm whether it is used further down.
                time_now = strftime("%m-%d %H:%M:%S", gmtime())
                p.put_step(epoch, global_step)
                p.put_list([np.mean(recon_l), np.mean(kl_l), np.mean(bow_l)])
                recon_l = []
                kl_l = []
                bow_l = []
            # Every `test_step * 10` steps: evaluate on the test batches, report
            # the results, and conditionally write a checkpoint.
            if global_step % (test_step * 10) == 0:
                """ EVAL and INFER """
                # Evaluate on the test batches; the seventh returned value is
                # discarded.
                (test_recon_loss, test_kl_loss, test_bow_loss, test_ppl,
                 test_bleu_score, precisions,
                 _) = cvae.infer_and_eval(test_batches)
                p.put_example(cvae)

                p.put_step(epoch, global_step)
                put_eval(test_recon_loss, test_kl_loss, test_bow_loss,
                         test_ppl, test_bleu_score, precisions, "TEST", log_f)

                # Checkpoints are written only once kl_weight has reached 0.35.
                # NOTE(review): presumably this skips saving during early KL
                # annealing -- the rationale for 0.35 is not visible here.
                if kl_weight >= 0.35:
                    path = join(output_dir,
                                "breakpoints/at_step_%d.ckpt" % global_step)
                    save_path = saver.save(sess, path)
            # Advance the step counter once per iteration of the enclosing loop.
            global_step += 1
    """GENERATE"""
    # TRAIN SET
    train_batches = batch_generator(train_data,
                                    start_i,
コード例 #2
0
    saver1.restore(sess1, "classifier/07-16_14-33-58/breakpoints/best_test_loss.ckpt")
# NOTE(review): presumably graph1 is a tf.Graph and this freezes it against
# further modification after the restore above -- confirm against its definition.
graph1.finalize()

"""build data"""
# Build the test and training datasets from their original/reply files using
# the shared word2index vocabulary. Test batches are generated without
# permutation (permutate=False).
test_data = build_data(test_ori_f, test_rep_f, word2index)
test_batches = batch_generator(test_data, start_i, end_i, batch_size, permutate=False)
train_data = build_data(train_ori_f, train_rep_f, word2index)


# Step and epoch counters all start at 1; the best_* variables start at the
# same initial values.
global_step = best_step = 1
start_epoch = best_epoch = 1

# NOTE(review): under Python 3 `/` is true division, so total_step is a float.
# Confirm downstream uses tolerate that (or whether `//` was intended). The
# meaning of the factor 8 is not visible here.
total_step = (8 * len(train_data[0]) / batch_size)

# Evaluate the model on the test batches and report the metrics before the
# epoch loop that follows; the seventh returned value is discarded.
(test_recon_loss, test_kl_loss, test_bow_loss,
 perplexity, test_bleu_score, precisions, _) = seq2seq.infer_and_eval(test_batches)
p.put_example(seq2seq)
p.put_bleu(
    test_recon_loss, test_kl_loss, test_bow_loss,
    perplexity, test_bleu_score, precisions, "TEST")


# Run the classifier-based generation evaluation on the same test batches and
# log the three returned values.
lengths, ac, ac5 = seq2seq.policy_gen_eval(test_batches, classifier)
p.put_list([lengths, ac, ac5])

# NOTE(review): s, e and step are not read within the visible lines; their
# purpose must be confirmed where they are consumed.
s = 1.
e = 1.
step = 500

for epoch in range(start_epoch, num_epoch + 1):
    train_batches = batch_generator(train_data, start_i, end_i, batch_size, permutate=True)