Example #1
def run_model(opt, train, val, ixtoword):
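    """Train textGAN on `train`, validating on `val`.

    Alternates discriminator and generator updates and periodically reports
    losses, synthetic sentences, and BLEU on random validation batches.
    """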

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
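        # x_ holds the noised input batch (CNN layout); x_org_ holds the
        # original sentences. textGAN returns the fetches plus separate
        # generator and discriminator train ops.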
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        # is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        # for i in range(34):
        #     valid_index = np.random.choice(
        #         len(val), opt.batch_size)
        #     val_sents = [val[t] for t in valid_index]
        #     val_sents_permutated = add_noise(val_sents, opt)
        #     x_val_batch = prepare_data_for_cnn(
        #         val_sents_permutated, opt)
        #     x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
        #     res = sess.run(res_, feed_dict={
        #                     x_: x_val_batch, x_org_: x_val_batch_org})
        #     if i == 0:
        #         valid_text = res['syn_sent']
        #     else:
        #         valid_text = np.concatenate(
        #             (valid_text, res['syn_sent']), 0)

        # np.savetxt('./text_news/vae_words.txt', valid_text, fmt='%i', delimiter=' ')
        # pdb.set_trace()

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

                #sents[0] = np.random.permutation(sents[0])
                x_batch = prepare_data_for_cnn(sents_permutated,
                                               opt)  # Batch L
                x_batch_org = prepare_data_for_rnn(sents, opt)
                d_loss = 0
                g_loss = 0
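                # Alternate GAN updates: run the discriminator op every
                # opt.dis_steps iterations and the generator op every
                # opt.gen_steps iterations.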
                if profile:
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run(
                            [dis_op, d_loss_],
                            feed_dict={
                                x_: x_batch,
                                x_org_: x_batch_org
                            },
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run(
                            [gen_op, g_loss_],
                            feed_dict={
                                x_: x_batch,
                                x_org_: x_batch_org
                            },
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                else:
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run([dis_op, d_loss_],
                                             feed_dict={
                                                 x_: x_batch,
                                                 x_org_: x_batch_org
                                             })
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run([gen_op, g_loss_],
                                             feed_dict={
                                                 x_: x_batch,
                                                 x_org_: x_batch_org
                                             })
                ''' validation '''
                if uidx % opt.valid_freq == 0:

                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    x_val_batch = prepare_data_for_cnn(val_sents_permutated,
                                                       opt)
                    x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    d_loss_val = sess.run(d_loss_,
                                          feed_dict={
                                              x_: x_val_batch,
                                              x_org_: x_val_batch_org
                                          })
                    g_loss_val = sess.run(g_loss_,
                                          feed_dict={
                                              x_: x_val_batch,
                                              x_org_: x_val_batch_org
                                          })

                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org
                                   })
                    print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                          (d_loss_val, g_loss_val, res['mean_dist']))
                    print("Sent:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]))  #.encode('utf-8', 'ignore').decode("utf8").strip())
                    print("MMD loss %f, GAN loss %f" %
                          (res['mmd'], res['gan']))
                    # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    for i in range(4):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        val_sents = [val[t] for t in valid_index]
                        val_sents_permutated = add_noise(val_sents, opt)
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                        res = sess.run(res_,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org
                                       })
                        if i == 0:
                            valid_text = res['syn_sent']
                        else:
                            valid_text = np.concatenate(
                                (valid_text, res['syn_sent']), 0)

                    np.savetxt('./text_news/syn_val_words.txt',
                               valid_text,
                               fmt='%i',
                               delimiter=' ')

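                    # Score the synthetic sentences against the sampled
                    # validation batch, used as the BLEU reference set.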
                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['syn_sent']],
                        {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu2s, bleu3s, bleu4s)]))

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org
                                       })
                    test_writer.add_summary(summary, uidx)
Example #2
def run_model(opt, train, val, ixtoword):
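    """Train the textGAN variant that feeds only x_ (x_org_ is unused).

    Alternates discriminator and generator updates, logs samples and BLEU,
    and checkpoints after every epoch.
    """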

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:'+ str(params['Wemb'].shape) + ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())


    uidx = 0
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()


    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("\nStarting epoch %d\n" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                print "\rIter: %d" % uidx,
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

                #sents[0] = np.random.permutation(sents[0])
                x_batch = prepare_data_for_cnn(sents_permutated, opt) # Batch L
                if x_batch.shape[0] == opt.batch_size:
                    d_loss = 0
                    g_loss = 0
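                    # When profiling, collect a FULL_TRACE of each step into
                    # run_metadata for the tfprof report printed below.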
                    if profile:
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run([dis_op, d_loss_], feed_dict={x_: x_batch},options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),run_metadata=run_metadata)
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run([gen_op, g_loss_], feed_dict={x_: x_batch},options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),run_metadata=run_metadata)
                    else:
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run([dis_op, d_loss_], feed_dict={x_: x_batch})
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run([gen_op, g_loss_], feed_dict={x_: x_batch})

                if uidx % opt.valid_freq == 0:
                    is_train = True
                    # print('Valid Size:', len(val))
                    valid_index = np.random.choice(len(val), opt.batch_size)

                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)

                    d_loss_val = sess.run(d_loss_, feed_dict={x_: x_val_batch})
                    g_loss_val = sess.run(g_loss_, feed_dict={x_: x_val_batch})


                    res = sess.run(res_, feed_dict={x_: x_val_batch})
                    print("Validation d_loss %f, g_loss %f  mean_dist %f" % (d_loss_val, g_loss_val, res['mean_dist']))
                    print "Sent:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
                    np.savetxt('./text/rec_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
                    if opt.discrimination:
                        print ("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))


                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU([prepare_for_bleu(s) for s in res['syn_sent']], {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join([str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])

                    summary = sess.run(merged, feed_dict={x_: x_val_batch})
                    test_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    #pdb.set_trace()
                    res = sess.run(res_, feed_dict={x_: x_batch})
                    median_dis = np.sqrt(np.median([((x-y)**2).sum() for x in res['real_f'] for y in res['real_f']]))
                    print("Iteration %d: d_loss %f, g_loss %f, mean_dist %f, realdist median %f" % (uidx, d_loss, g_loss, res['mean_dist'], median_dis))
                    np.savetxt('./text/rec_train_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
                    print "Sent:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()

                    summary = sess.run(merged, feed_dict={x_: x_batch})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(),
                        run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path, global_step=epoch)
Example #3
def main():
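    """Train a denoising autoencoder on hotel reviews and report
    reconstructions, validation loss, and BLEU."""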
    # global n_words
    # Prepare training and testing data
    # loadpath = "./data/three_corpus_small.p"
    loadpath = "./data/hotel_reviews.p"
    x = cPickle.load(open(loadpath, "rb"))
    train, val = x[0], x[1]
    wordtoix, ixtoword = x[2], x[3]

    opt = Options()
    opt.n_words = len(ixtoword) + 1
    ixtoword[opt.n_words - 1] = 'GO_'
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' + str(
                      (opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, loss_, train_op = auto_encoder(x_, x_org_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    # writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                # pdb.set_trace()

                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

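                # opt.substitution selects the noise type: 's'ubstitute,
                # 'p'ermute, 'a'dd, or 'd'elete words; any other value
                # leaves the sentences unchanged.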
                if opt.substitution == 's':
                    sents_permutated = substitute_sent(sents, opt)
                elif opt.substitution == 'p':
                    sents_permutated = permutate_sent(sents, opt)
                elif opt.substitution == 'a':
                    sents_permutated = add_sent(sents, opt)
                elif opt.substitution == 'd':
                    sents_permutated = delete_sent(sents, opt)
                else:
                    sents_permutated = sents

                # sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated,
                                                   opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated,
                                                   opt,
                                                   is_add_GO=False)  # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print

                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()

                _, loss = sess.run([train_op, loss_],
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org
                                   })

                if uidx % opt.valid_freq == 0:
                    opt.is_train = False
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    if opt.substitution == 's':
                        val_sents_permutated = substitute_sent(val_sents, opt)
                    elif opt.substitution == 'p':
                        val_sents_permutated = permutate_sent(val_sents, opt)
                    elif opt.substitution == 'a':
                        val_sents_permutated = add_sent(val_sents, opt)
                    elif opt.substitution == 'd':
                        val_sents_permutated = delete_sent(val_sents, opt)
                    else:
                        val_sents_permutated = val_sents

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={
                                            x_: x_val_batch,
                                            x_org_: x_val_batch_org
                                        })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org
                                   })
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    print "Val Orig :" + " ".join(
                        [ixtoword[x] for x in val_sents[0] if x != 0])
                    print "Val Perm :" + " ".join([
                        ixtoword[x] for x in val_sents_permutated[0] if x != 0
                    ])
                    print "Val Recon:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])
                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org
                                       })
                    test_writer.add_summary(summary, uidx)
                    opt.is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org
                                   })
                    print "Original     :" + " ".join(
                        [ixtoword[x] for x in sents[0] if x != 0])
                    print "Permutated   :" + " ".join(
                        [ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        print "Reconstructed:" + " ".join([
                            ixtoword[x]
                            for x in res['rec_sents_feed_y'][0] if x != 0
                        ])
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org
                                       })
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            saver.save(sess, opt.save_path, global_step=epoch)
Example #4
def main():
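    """Train a conditional seq2seq GAN on pairs of adjacent sentences and
    report BLEU, entropy, and diversity on held-out data."""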
    opt = Options(args)
    opt_t = Options(args)
    opt_t.n_hid = opt.n_z
    loadpath = (opt.data_dir + "/" + opt.data_name) 
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    
    
    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)
    if opt.model == 'cnn_rnn':
        opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
        opt_t.update_params(args)
        print dict(opt_t)
    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            z_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.n_z])
            res_, gan_cost_g_, train_op_g = conditional_s2s(src_, tgt_, z_, opt, opt_t)
            merged = tf.summary.merge_all()
    
    
    uidx = 0
    graph_options=tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()
    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                
                t_vars = tf.trainable_variables()                  
                if opt.load_from_pretrain:
                    d_vars = [var for var in t_vars if var.name.startswith('d_')]
                    g_vars = [var for var in t_vars if var.name.startswith('g_')]
                    restore_from_save(g_vars, sess, opt, opt.restore_dir + "/save/generator")
                    restore_from_save(d_vars, sess, opt, opt.restore_dir + "/save/discriminator")
                else:
                    loader = restore_from_save(t_vars, sess, opt)  
            except Exception as e:
                print 'Error: '+str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0
        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                indice = [(x,x+1) for x in range(7)]
                sents = [train[t] for t in train_index]
                for idx in indice:
                    src = [sents[i][idx[0]] for i in range(opt.batch_size)]
                    tgt = [sents[i][idx[1]] for i in range(opt.batch_size)]
                    src_permutated = src 
                    x_batch = prepare_data_for_cnn(src_permutated, opt) 
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False) 
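                    # Draw the latent code from the configured prior:
                    # Gaussian ('g') or uniform on [-1, 1).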
                    if opt.z_prior == 'g':
                        z_batch = np.random.normal(0, 1, (opt.batch_size, opt.n_z)).astype('float32')  # match the z_ placeholder shape
                    else:
                        z_batch = np.random.uniform(-1,1,(opt.batch_size, opt.n_z)).astype('float32')
                    
                    feed = {src_: x_batch, tgt_: y_batch, z_:z_batch}
                    _, loss_g = sess.run([train_op_g, gan_cost_g_],feed_dict=feed)
                if uidx%opt.print_freq == 0:
                    print("Iteration %d: loss G %f" %(uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    print "Source:" + u' '.join([ixtoword[x] for x in x_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([ixtoword[x] for x in y_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)
                    
                    
                if uidx%opt.valid_freq == 1:
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(np.floor(VALID_SIZE/opt.batch_size))
                    res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(test), opt.batch_size)
                        indice = [(x,x+1) for x in range(7)]
                        val_sents = [test[t] for t in valid_index]
                        for idx in indice:
                            val_src = [val_sents[i][idx[0]] for i in range(opt.batch_size)]
                            val_tgt = [val_sents[i][idx[1]] for i in range(opt.batch_size)]
                            val_tgt_all.extend(val_tgt)
                            val_src_permutated = val_src 
                            
                            x_val_batch = prepare_data_for_cnn(val_src, opt) 
                            
                            y_val_batch = prepare_data_for_rnn(val_tgt, opt_t, is_add_GO = False) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(val_tgt, opt_t)
                            if opt.z_prior == 'g':
                                z_val_batch = np.random.normal(0,1,(opt.batch_size, opt.n_z)).astype('float32')
                            else:
                                z_val_batch = np.random.uniform(-1,1,(opt.batch_size, opt.n_z)).astype('float32')
                            feed_val = {src_: x_val_batch, tgt_: y_val_batch, z_:z_val_batch}
                            loss_val_g = sess.run([gan_cost_g_], feed_dict=feed_val)
                            loss_val_g_all.append(loss_val_g)
                            res = sess.run(res_, feed_dict=feed_val)
                            res_all.extend(res['syn_sent'])
                        if opt.global_feature:
                            z_loss_all.append(res['z_loss'])
                        
                    print("Validation:  loss G %f " %( np.mean(loss_val_g_all)))
                    
                    print "Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()
                    print "Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    if opt.global_feature:
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" %(np.mean(z_loss_all))+ "\n")
                            myfile.write("Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()+ "\n")
                            myfile.write("Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()+ "\n")
                            myfile.write("Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()+ "\n")
                            myfile.write(np.array2string(res['z'][0], formatter={'float_kind':lambda x: "%.2f" % x})+ "\n")
                            myfile.write(np.array2string(res['z_hat'][0], formatter={'float_kind':lambda x: "%.2f" % x})+ "\n\n")
                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    
                    [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
                    
                    etp_score, dist_score = cal_entropy(gen)
                    
                    print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)])
                    
                    print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])])
                    print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])])
                    
                    print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3)) 
                    print ""
                    summary = sess.run(merged, feed_dict=feed_val)
                    summary2 = tf.Summary(value=[tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),tf.Summary.Value(tag="etp-4", simple_value=etp_score[3])])
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)
                if uidx%opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
Example #5
def main():
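    """Train a dialog GAN on Reddit source/target pairs and evaluate with
    BLEU, ROUGE, entropy, diversity, and embedding-based relevance."""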
    # Prepare training and testing data

    loadpath = "./data/"

    src_file = loadpath + "Pairs2M.src.num"
    tgt_file = loadpath + "Pairs2M.tgt.num"
    dic_file = loadpath + "Pairs2M.reddit.dic"

    opt = Options()
    opt_t = Options()

    train, val, test, wordtoix, ixtoword = read_pair_data_full(
        src_file,
        tgt_file,
        dic_file,
        max_num=opt.data_size,
        p_f=loadpath + 'demo.p')
    train = [
        x for x in train
        if 2 < len(x[1]) < opt.maxlen - 4 and 2 < len(x[0]) < opt_t.maxlen - 4
    ]
    val = [
        x for x in val
        if 2 < len(x[1]) < opt.maxlen - 4 and 2 < len(x[0]) < opt_t.maxlen - 4
    ]

    if TEST_FLAG:
        test = [test]
        opt.test_freq = 1

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    print dict(opt)
    if opt.model == 'cnn_rnn':
        opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
        opt_t.update_params()
        print dict(opt_t)

    print('Total words: %d' % opt.n_words)

    # load w2v
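    # Cache the vocabulary-restricted embedding dict so the full word2vec
    # binary only needs to be parsed once.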
    if os.path.exists(opt.embedding_path_lime):
        with open(opt.embedding_path_lime, 'rb') as pfile:
            embedding = cPickle.load(pfile)
    else:
        w2v = gensim.models.KeyedVectors.load_word2vec_format(
            opt.embedding_path, binary=True)
        embedding = {
            i: copy.deepcopy(w2v[ixtoword[i]])
            for i in range(opt.n_words) if ixtoword[i] in w2v
        }
        with open(opt.embedding_path_lime, 'wb') as pfile:
            cPickle.dump(embedding, pfile, protocol=cPickle.HIGHEST_PROTOCOL)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = tf.placeholder(tf.int32,
                                  shape=[opt.batch_size, opt.sent_len])
            tgt_ = tf.placeholder(tf.int32,
                                  shape=[opt_t.batch_size, opt_t.sent_len])
            res_, gan_cost_d_, train_op_d, gan_cost_g_, train_op_g = dialog_gan(
                src_, tgt_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.95

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:

                t_vars = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES)  #tf.trainable_variables()

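                # Match checkpoint keys to graph variables by name and shape:
                # var.name[2:] strips the two-character scope prefix (e.g.
                # 'g_') and [:-2] drops the ':0' tensor suffix.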
                if opt.load_from_ae:
                    save_keys = tensors_key_in_file(
                        opt.load_path)  #t_var g_W:0    key: W
                    ss = [
                        var for var in t_vars
                        if var.name[2:][:-2] in save_keys.keys()
                    ]
                    ss = [
                        var.name[2:] for var in ss
                        if var.get_shape() == save_keys[var.name[2:][:-2]]
                    ]
                    cc = {
                        var.name[2:][:-2]: var
                        for var in t_vars if var.name[2:] in ss
                    }

                    loader = tf.train.Saver(var_list=cc)
                    loader.restore(sess, opt.load_path)

                    print("Loading variables from '%s'." % opt.load_path)
                    print(
                        "Loaded variables:" + " ".join(
                            [var.name
                             for var in t_vars if var.name[2:] in ss]))
                else:
                    save_keys = tensors_key_in_file(opt.load_path)
                    ss = [
                        var for var in t_vars
                        if var.name[:-2] in save_keys.keys()
                    ]
                    ss = [
                        var.name for var in ss
                        if var.get_shape() == save_keys[var.name[:-2]]
                    ]
                    loader = tf.train.Saver(
                        var_list=[var for var in t_vars if var.name in ss])
                    loader.restore(sess, opt.load_path)
                    print("Loading variables from '%s'." % opt.load_path)
                    print("Loaded variables:" + str(ss))
                    # load reverse model
                    try:
                        save_keys = tensors_key_in_file('./save/rev_model')
                        ss = [
                            var for var in t_vars
                            if var.name[:-2] in save_keys.keys()
                            and 'g_rev_' in var.name
                        ]
                        ss = [
                            var.name for var in ss
                            if var.get_shape() == save_keys[var.name[:-2]]
                        ]
                        loader = tf.train.Saver(
                            var_list=[var for var in t_vars if var.name in ss])
                        loader.restore(sess, './save/rev_model')
                        print(
                            "Loading reverse variables from ./save/rev_model")
                        print("Loaded variables:" + str(ss))
                    except Exception as e:
                        print("No reverse model loaded")

            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0
        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:

                uidx += 1

                if uidx % opt.test_freq == 1:
                    iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
                    res_all, test_tgt_all = [], []

                    for i in range(iter_num):
                        test_index = range(i * opt.batch_size,
                                           (i + 1) * opt.batch_size)
                        test_tgt, test_src = zip(
                            *[test[t % len(test)] for t in test_index])
                        test_tgt_all.extend(test_tgt)
                        x_batch = prepare_data_for_cnn(test_src, opt)
                        y_batch = prepare_data_for_rnn(
                            test_tgt, opt_t, is_add_GO=False
                        ) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(
                            test_tgt, opt_t)
                        feed = {src_: x_batch, tgt_: y_batch}
                        res = sess.run(res_, feed_dict=feed)
                        res_all.extend(res['syn_sent'])

                    test_set = [prepare_for_bleu(s) for s in test_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: test_set},
                                          is_corpus=opt.is_corpus)
                    [rouge1, rouge2, rouge3, rouge4, rougeL,
                     rouges] = cal_ROUGE(gen, {0: test_set},
                                         is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    bleu_nltk = cal_BLEU_4_nltk(gen,
                                                test_set,
                                                is_corpus=opt.is_corpus)
                    rel_score = cal_relevance(gen, test_set, embedding)

                    print 'Test BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)
                    ])
                    print 'Test Rouge: ' + ' '.join([
                        str(round(it, 3))
                        for it in (rouge1, rouge2, rouge3, rouge4)
                    ])
                    print 'Test Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ])
                    print 'Test Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1], dist_score[2],
                                   dist_score[3])
                    ])
                    print 'Test Relevance(G,A,E): ' + ' '.join([
                        str(round(it, 3))
                        for it in (rel_score[0], rel_score[1], rel_score[2])
                    ])
                    print 'Test Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0]) for x in res_all
                            ]), 3))
                    print ''

                    if TEST_FLAG:
                        exit()

                tgt, src = zip(*[train[t] for t in train_index])
                x_batch = prepare_data_for_cnn(src, opt)  # Batch L

                y_batch = prepare_data_for_rnn(
                    tgt, opt_t, is_add_GO=False
                ) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(
                    tgt, opt_t)

                feed = {src_: x_batch, tgt_: y_batch}

                if uidx % opt.d_freq == 1:
                    if profile:
                        _, loss_d = sess.run(
                            [train_op_d, gan_cost_d_],
                            feed_dict=feed,
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    else:
                        _, loss_d = sess.run([train_op_d, gan_cost_d_],
                                             feed_dict=feed)

                if uidx % opt.g_freq == 1:
                    if profile:
                        _, loss_g = sess.run(
                            [train_op_g, gan_cost_g_],
                            feed_dict=feed,
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    else:
                        _, loss_g = sess.run([train_op_g, gan_cost_g_],
                                             feed_dict=feed)

                if profile:
                    tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(),
                        run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.
                        PRINT_ALL_TIMING_MEMORY)
                    exit(0)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 1024
                    valid_multiplier = np.int(
                        np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        val_tgt, val_src = zip(*[val[t] for t in valid_index])
                        val_tgt_all.extend(val_tgt)
                        x_val_batch = prepare_data_for_cnn(val_src,
                                                           opt)  # Batch L

                        y_val_batch = prepare_data_for_rnn(
                            val_tgt, opt_t, is_add_GO=False
                        ) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(
                            val_tgt, opt_t)

                        feed_val = {src_: x_val_batch, tgt_: y_val_batch}
                        loss_val_d, loss_val_g = sess.run(
                            [gan_cost_d_, gan_cost_g_], feed_dict=feed_val)
                        loss_val_d_all.append(loss_val_d)
                        loss_val_g_all.append(loss_val_g)
                        res = sess.run(res_, feed_dict=feed_val)
                        res_all.extend(res['syn_sent'])

                    print("Validation: loss D %f loss G %f " %
                          (np.mean(loss_val_d_all), np.mean(loss_val_g_all)))
                    #print "Val Perm :" + " ".join([ixtoword[x] for x in val_src_permutated[0] if x != 0])
                    print "Val Source:" + u' '.join([
                        ixtoword[x] for x in val_src[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Target :" + u' '.join([
                        ixtoword[x] for x in val_tgt[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]

                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: val_set},
                                          is_corpus=opt.is_corpus)
                    [rouge1, rouge2, rouge3, rouge4, rougeL,
                     rouges] = cal_ROUGE(gen, {0: val_set},
                                         is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    bleu_nltk = cal_BLEU_4_nltk(gen,
                                                val_set,
                                                is_corpus=opt.is_corpus)
                    rel_score = cal_relevance(gen, val_set, embedding)

                    print 'Val BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)
                    ])
                    print 'Val Rouge: ' + ' '.join([
                        str(round(it, 3))
                        for it in (rouge1, rouge2, rouge3, rouge4)
                    ])
                    print 'Val Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ])
                    print 'Val Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1], dist_score[2],
                                   dist_score[3])
                    ])
                    print 'Val Relevance(G,A,E): ' + ' '.join([
                        str(round(it, 3))
                        for it in (rel_score[0], rel_score[1], rel_score[2])
                    ])
                    print 'Val Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0]) for x in res_all
                            ]), 3))
                    print ""
                    summary = sess.run(merged, feed_dict=feed_val)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="rouge-2", simple_value=rouge2),
                        tf.Summary.Value(tag="etp-4",
                                         simple_value=etp_score[3])
                    ])

                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.print_freq == 1:
                    print("Iteration %d: loss D %f loss G %f" %
                          (uidx, loss_d, loss_g))

                    res = sess.run(res_, feed_dict=feed)

                    if opt.grad_penalty:
                        print "grad_penalty: " + str(res['gp'])
                    print "Source:" + u' '.join([
                        ixtoword[x] for x in x_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""

                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.save_freq == 1:
                    saver.save(sess, opt.save_path)
Example #6
def run_model(opt, train, val, test, wordtoix, ixtoword):
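    """Train an autoencoder driven by an explicit is_train_ flag; validation
    reports reconstructions, BLEU, and character error rate (CER)."""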

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        summary_ext = tf.Summary()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

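                # Choose CNN- or RNN-style batch layout to match the
                # encoder/decoder combination named by opt.model.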
                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated,
                                                   opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated,
                                                   opt,
                                                   is_add_GO=False)  # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print

                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()

                #
                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={
                            x_: x_batch,
                            x_org_: x_batch_org,
                            is_train_: 1
                        },
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org,
                                           is_train_: 1
                                       })

                #pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    is_train = None
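                    # Validation pass: is_train is cleared here (and restored
                    # to True after this block) so layers gated by is_train_
                    # presumably run in inference mode.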
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={
                                            x_: x_val_batch,
                                            x_org_: x_val_batch_org,
                                            is_train_: is_train
                                        })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org,
                                       is_train_: is_train
                                   })
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    if opt.char:
                        print "Val Orig :" + "".join(
                            [ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + "".join([
                            ixtoword[x]
                            for x in val_sents_permutated[0] if x != 0
                        ])
                        print "Val Recon:" + "".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])
                        # print "Val Recon one hot:" + "".join([ixtoword[x] for x in res['rec_sents_one_hot'][0] if x != 0])
                    else:
                        print "Val Orig :" + " ".join(
                            [ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + " ".join([
                            ixtoword[x]
                            for x in val_sents_permutated[0] if x != 0
                        ])
                        print "Val Recon:" + " ".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])

                    val_set_char = [
                        prepare_for_cer(s, ixtoword) for s in val_sents
                    ]
                    cer = cal_cer([
                        prepare_for_cer(s, ixtoword) for s in res['rec_sents']
                    ], val_set_char)
                    print 'Val CER: ' + str(round(cer, 3))
                    # summary_ext.Value(tag='CER', simple_value=cer)
                    summary_ext = tf.Summary(
                        value=[tf.Summary.Value(tag='CER', simple_value=cer)])
                    # tf.summary.scalar('CER', cer)

                    #if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #print "Gen Probs:" + " ".join([str(np.round(res['gen_p'][i], 1)) for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org,
                                           is_train_: is_train
                                       })
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary_ext, uidx)
                    is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org,
                                       is_train_: 1
                                   })

                    # if 1 in res['rec_sents'][0] or 1 in sents[0]:
                    #     pdb.set_trace()
                    if opt.char:
                        print "Original     :" + "".join(
                            [ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated   :" + "".join([
                            ixtoword[x] for x in sents_permutated[0] if x != 0
                        ])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + " ".join([
                                ixtoword[x]
                                for x in res['rec_sents_feed_y'][0] if x != 0
                            ])
                        print "Reconstructed:" + "".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])

                    else:
                        print "Original     :" + " ".join(
                            [ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated   :" + " ".join([
                            ixtoword[x] for x in sents_permutated[0] if x != 0
                        ])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + " ".join([
                                ixtoword[x]
                                for x in res['rec_sents_feed_y'][0] if x != 0
                            ])
                        print "Reconstructed:" + " ".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org,
                                           is_train_: 1
                                       })
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                            tf.get_default_graph(),
                            run_meta=run_metadata,
                            tfprof_options=tf.contrib.tfprof.model_analyzer.
                            PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path)
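
get_minibatches_idx is not part of this excerpt; judging from its call sites above (iterating over (index, train_index) pairs once per epoch), a minimal sketch could look like the following. The name is real but the body is an assumption reconstructed from usage; partial trailing batches are dropped because the placeholders are fixed to batch_size.

import numpy as np

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    # Assumed helper: shuffle the n indices and cut them into
    # minibatch-sized chunks, yielding (minibatch_index, index_list) pairs.
    idx_list = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = [idx_list[i:i + minibatch_size]
                   for i in range(0, n - minibatch_size + 1, minibatch_size)]
    return zip(range(len(minibatches)), minibatches)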
Example no. 7
def main():
    #global n_words
    # Prepare training and testing data
    #loadpath = "./data/three_corpus_small.p"
    #loadpath = "./data/three_corpus_corrected_large.p"
    loadpath = "../data/reddit_2m/"
    dic_file = loadpath + "Pairs2M.reddit.dic"

    wordtoix, ixtoword = {}, {}
    print "Start reading dic file . . ."
    if os.path.exists(dic_file):
        print("loading Dictionary")
        counter = 0
        with codecs.open(dic_file, "r", 'utf-8') as f:
            s = f.readline()
            while s:
                s = s.rstrip('\n').rstrip("\r")
                #print("s==",s)
                wordtoix[s] = counter
                ixtoword[counter] = s
                counter += 1
                s = f.readline()

    target, response = [], []
    with codecs.open(args.target, "r", 'utf-8') as f:
        line = f.readline().rstrip("\n").rstrip("\r")
        while line:
            target.append(
                [wordtoix[x] if x in wordtoix else 3 for x in line.split()])
            line = f.readline().rstrip("\n").rstrip("\r")

    with codecs.open(args.response, "r", 'utf-8') as f:
        line = f.readline().rstrip("\n").rstrip("\r")
        while line:
            response.append(
                [wordtoix[x] if x in wordtoix else 3 for x in line.split()])
            line = f.readline().rstrip("\n").rstrip("\r")
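    # Both read loops map whitespace tokens to vocabulary indices, falling
    # back to id 3 for out-of-vocabulary words (presumably the UNK token id).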

    opt = Options()
    # opt_t = Options()

    # opt.test_freq = 1

    # # opt_t.maxlen = 101 #49
    # # opt_t.update_params()

    opt.n_words = len(ixtoword)
    # opt_t.n_words = len(ixtoword)
    # print dict(opt)
    # if opt.model == 'cnn_rnn':
    #     opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    #     opt_t.update_params()
    #     print dict(opt_t)

    print('Total words: %d' % opt.n_words)

    if os.path.exists(opt.embedding_path_lime):
        with open(opt.embedding_path_lime, 'rb') as pfile:
            embedding = cPickle.load(pfile)
    else:
        w2v = gensim.models.KeyedVectors.load_word2vec_format(
            opt.embedding_path, binary=True)
        #wl = [ixtoword[i] for i in range(opt.n_words) if ixtoword[i] in w2v]
        #w2v[wl].gensim.models.KeyedVectors.save_word2vec_format(opt.embedding_path + '_lime', binary=True)
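        # Keep only in-vocabulary words: build a reduced id -> vector dict and
        # cache it with cPickle so the full word2vec model is loaded only once.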
        embedding = {
            i: copy.deepcopy(w2v[ixtoword[i]])
            for i in range(opt.n_words) if ixtoword[i] in w2v
        }
        with open(opt.embedding_path_lime, 'wb') as pfile:
            cPickle.dump(embedding, pfile, protocol=cPickle.HIGHEST_PROTOCOL)

    test_set = [prepare_for_bleu(s) for s in target]
    res_all = response
    [bleu1s, bleu2s, bleu3s,
     bleu4s] = cal_BLEU_4([prepare_for_bleu(s) for s in res_all],
                          {0: test_set},
                          is_corpus=opt.is_corpus)
    [rouge1, rouge2, rouge3, rouge4, rougeL,
     rouges] = cal_ROUGE([prepare_for_bleu(s) for s in res_all], {0: test_set},
                         is_corpus=opt.is_corpus)
    etp_score, dist_score = cal_entropy([prepare_for_bleu(s) for s in res_all])
    bleu_nltk = cal_BLEU_4_nltk([prepare_for_bleu(s) for s in res_all],
                                test_set,
                                is_corpus=opt.is_corpus)
    rel_score = cal_relevance([prepare_for_bleu(s) for s in res_all], test_set,
                              embedding)

    print 'Test BLEU: ' + ' '.join([
        str(round(it, 3)) for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)
    ])
    print 'Test Rouge: ' + ' '.join(
        [str(round(it, 3)) for it in (rouge1, rouge2, rouge3, rouge4)])
    print 'Test Entropy: ' + ' '.join([
        str(round(it, 3))
        for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])
    ])
    print 'Test Diversity: ' + ' '.join([
        str(round(it, 3))
        for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])
    ])
    print 'Test Relevance(G,E,A): ' + ' '.join([
        str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])
    ])
    print ''
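
cal_entropy is not shown in this excerpt. Going by the four entropy and four diversity values printed above, it plausibly computes n-gram entropy and distinct-n ratios for n = 1..4; a self-contained sketch under that assumption (the name is suffixed to mark it as ours):

from collections import Counter
import numpy as np

def cal_entropy_sketch(generated):
    # Assumed behavior: for n = 1..4, compute the entropy of the n-gram
    # distribution and the distinct-n ratio over the generated token lists.
    etp_score, dist_score = [], []
    for n in range(1, 5):
        counter = Counter(tuple(g[i:i + n]) for g in generated
                          for i in range(len(g) - n + 1))
        total = float(sum(counter.values())) or 1.0
        probs = np.array([c / total for c in counter.values()])
        etp_score.append(float(-np.sum(probs * np.log(probs))))
        dist_score.append(len(counter) / total)
    return etp_score, dist_score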
Example no. 8
def run_model(opt, train, val, ixtoword):

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)

    # keep all checkpoints
    saver = tf.train.Saver(max_to_keep=None)

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nloaded successfully\n')

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        ''' validation '''
        valid_index = np.random.choice(len(val), opt.batch_size)
        val_sents = [val[t] for t in valid_index]

        val_sents_permutated = add_noise(val_sents, opt)

        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

        d_loss_val = sess.run(d_loss_,
                              feed_dict={
                                  x_: x_val_batch,
                                  x_org_: x_val_batch_org
                              })
        g_loss_val = sess.run(g_loss_,
                              feed_dict={
                                  x_: x_val_batch,
                                  x_org_: x_val_batch_org
                              })

        res = sess.run(res_,
                       feed_dict={
                           x_: x_val_batch,
                           x_org_: x_val_batch_org
                       })
        try:
            print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + u' '.join([
                ixtoword[x] for x in res['syn_sent'][0] if x != 0
            ]))  #.encode('utf-8', 'ignore').decode("utf8").strip())
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
        except Exception as e:
            print(e)

        # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
        if opt.discrimination:
            print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))

        # Generate batches until at least 268590 sentences are produced
        # (the literal 1000 is presumably the batch size used here).
        for i in range(268590 // 1000 + 1):
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            val_sents_permutated = add_noise(val_sents, opt)
            x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
            x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
            res = sess.run(res_,
                           feed_dict={
                               x_: x_val_batch,
                               x_org_: x_val_batch_org
                           })
            if i == 0:
                valid_text = res['syn_sent']
            else:
                valid_text = np.concatenate((valid_text, res['syn_sent']), 0)

        valid_text = valid_text[:268590]
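        # PATH_TO_SAVE is an undefined placeholder in this snippet; point it
        # at the desired output file before running.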
        np.savetxt(PATH_TO_SAVE, valid_text, fmt='%i', delimiter=' ')
        print('saved!\n\n\n')
        exit()
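        # NOTE: the BLEU evaluation below is unreachable because of the exit()
        # call above.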

        val_set = [prepare_for_bleu(s) for s in val]  #val_sents]
        bleu_prepared = [prepare_for_bleu(s) for s in res['syn_sent']]
        for i in range(len(val_set) // opt.batch_size):
            batch = val_set[i * opt.batch_size:(i + 1) * opt.batch_size]
            [bleu2s, bleu3s, bleu4s] = cal_BLEU(bleu_prepared,
                                                {0: batch})  #val_set})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))
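
prepare_data_for_cnn and prepare_data_for_rnn are defined elsewhere; from their call sites (fixed [batch_size, sent_len] placeholders and an is_add_GO flag) they apparently pad index sequences to a fixed length. A minimal sketch of the CNN variant under those assumptions:

import numpy as np

def prepare_data_for_cnn_sketch(sents, opt):
    # Assumed behavior: right-pad (and truncate) each index sequence with 0s
    # to opt.sent_len, returning an int32 array of shape [batch, sent_len].
    batch = np.zeros((len(sents), opt.sent_len), dtype='int32')
    for row, s in enumerate(sents):
        s = list(s)[:opt.sent_len]
        batch[row, :len(s)] = s
    return batch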
Example no. 9
def run_model(opt, train, val, test, wordtoix, ixtoword):


    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:'+ str(params['Wemb'].shape) + ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())


    uidx = 0
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()



    run_metadata = tf.RunMetadata()


    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)


            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

                #sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt) # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt) # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt) # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO = False) # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print


                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()

                #
                if profile:
                    _, loss = sess.run([train_op, loss_], feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1},options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_], feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})

                #pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)


                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_:is_train })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_:is_train })
                    if opt.discrimination:
                        print ("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))
                    print "Val Orig :" + " ".join([ixtoword[x] for x in val_sents[0] if x != 0])
                    #print "Val Perm :" + " ".join([ixtoword[x] for x in val_sents_permutated[0] if x != 0])
                    print "Val Recon:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU([prepare_for_bleu(s) for s in res['rec_sents']], {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join([str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])

                    # if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #     print "Org Probs:" + " ".join(
                    #         [ixtoword[x_val_batch_org[0][i]] + '(' + str(np.round(res['all_p'][i], 1)) + ')' for i in
                    #          range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    #     print "Gen Probs:" + " ".join(
                    #         [ixtoword[res['rec_sents'][0][i]] + '(' + str(np.round(res['gen_p'][i], 1)) + ')' for i in
                    #          range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_:is_train })
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                def test_input(text):
                    x_input = sent2idx(text, wordtoix, opt)
                    res = sess.run(res_, feed_dict={x_: x_input, x_org_: x_batch_org})
                    print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                if uidx % opt.print_freq == 0:
                    #pdb.set_trace()
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                    print "Original     :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
                    #print "Permutated   :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
                    print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] +'(' +str(np.round(res['all_p'][i],2))+')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(),
                        run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path, global_step=epoch)
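
sent2idx (used by the test_input helper above) is not included here either; since it must produce something feedable to the [batch_size, sent_len] placeholder x_, one plausible sketch tiles a single padded sentence across the batch. Both the tiling and the UNK id of 3 are assumptions:

import numpy as np

def sent2idx_sketch(text, wordtoix, opt):
    # Assumed behavior: map one sentence to indices (3 ~ UNK, an assumption),
    # pad to opt.sent_len, and tile it batch_size times for the placeholder.
    ids = [wordtoix.get(w, 3) for w in text.split()][:opt.sent_len]
    row = np.zeros(opt.sent_len, dtype='int32')
    row[:len(ids)] = ids
    return np.tile(row, (opt.batch_size, 1))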
Example no. 10
def main():
    #global n_words
    # Prepare training and testing data

    opt = COptions(args)
    opt_t = COptions(args)
    # opt_t.n_hid = opt.n_z

    loadpath = (opt.data_dir + "/" + opt.data_name) #if opt.not_philly else '/hdfs/msrlabs/xiag/pt-data/cons/data_cleaned/twitter_small.p'
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    if opt.test:
        test_file = opt.data_dir + opt.test_file 
        test = read_test(test_file, wordtoix)
        # test = [ x for x in test if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # train_filtered = [ x for x in train if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # val_filtered = [ x for x in val if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # print ("Train: %d => %d" % (len(train), len(train_filtered)))
    # print ("Val: %d => %d" % (len(val), len(val_filtered)))
    # train, val = train_filtered, val_filtered
    # del train_filtered, val_filtered

    opt.n_words = len(ixtoword) 
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    # print dict(opt)
    # if opt.model == 'cnn_rnn':
    #     opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    #     opt_t.update_params(args)
        # print dict(opt_t)


    #for d in ['/gpu:0', '/gpu:1', '/gpu:2', '/gpu:3']:
    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            z_ = tf.placeholder(tf.float32, shape=[opt_t.batch_size , opt.n_z * (2 if opt.local_feature else 1)])
            is_train_ = tf.placeholder(tf.bool, name = 'is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
            res_2_ = generate_resp(src_, tgt_, z_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    graph_options=tf.GraphOptions(build_cost_model=1)
    #config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=tf.GraphOptions(build_cost_model=1))
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=graph_options)
    # config.gpu_options.per_process_gpu_memory_fraction = 0.70
    #config = tf.ConfigProto(device_count={'GPU':0})
    #config.gpu_options.allow_growth = True

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()  
                #t_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) #tf.trainable_variables()

                # if opt.load_from_pretrain:
                #     d_vars = [var for var in t_vars if var.name.startswith('d_')]
                #     g_vars = [var for var in t_vars if var.name.startswith('g_')]
                #     l_vars = [var for var in t_vars if var.name.startswith('l_')]
                #     #restore_from_save(g_vars, sess, opt, prefix = 'g_', load_path=opt.restore_dir + "/save/generator2")
                #     restore_from_save(d_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.global_d)
                #     if opt.local_feature:
                #         restore_from_save(l_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.local_d)
                # else:
                loader = restore_from_save(t_vars, sess, opt, load_path = opt.save_path)


            except Exception as e:
                print 'Error: ' + str(e)
                print("No saved session; using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            iter_num = np.int(np.floor(len(test)/opt.batch_size))+1
            res_all = []
            val_tgt_all =[]
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,(i+1) * opt.batch_size)
                sents = [test[t%len(test)] for t in test_index]
                for idx in range(opt.n_context,opt.num_turn):
                    src = [[sents[i][idx-turn] for i in range(opt.batch_size)] for turn in range(opt.n_context,0,-1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)] 
                    val_tgt_all.extend(tgt)
                    if opt.feed_generated and idx!= opt.n_context:
                        src[-1] = [[x for x in p if x!=0] for p in res_all[-opt.batch_size:]]

                    x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src] # Batch L
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False) 
                    
                    feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0}) # do not use False
                    res_1 = sess.run(res_1_, feed_dict=feed)
                    z_all = np.array(res_1['z'])

                    
                    feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, z_: z_all, is_train_: 0}) # do not use False
                    res_2 = sess.run(res_2_, feed_dict=feed)
                    res_all.extend(res_2['syn_sent'])

                    # bp()
   
            val_tgt_all = reshaping(val_tgt_all, opt)
            res_all = reshaping(res_all, opt)
            
            save_path = opt.log_path + '.resp.txt'
            if os.path.exists(save_path):
                os.remove(save_path) 
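            # Output format: one dialogue per line, generated turns separated
            # by tabs, with a newline after the last turn of each dialogue.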
            for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                with open(save_path, "a") as resp_f:
                    resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
            print ("save to:" + save_path)

            if opt.verbose:
                save_path = opt.log_path + '.tgt.txt'
                if os.path.exists(save_path):
                    os.remove(save_path) 
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path, "a") as tgt_f:
                        tgt_f.write(u' '.join([ixtoword[x] for x in val_tgt_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
                print ("save to:" + save_path)

            val_set = [prepare_for_bleu(s) for s in val_tgt_all]
            gen = [prepare_for_bleu(s) for s in res_all]
            [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
            etp_score, dist_score = cal_entropy(gen)

            # print save_path
            print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)])
            # print 'Val Rouge: ' + ' '.join([str(round(it,3)) for it in (rouge1,rouge2,rouge3,rouge4)])
            print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])])
            print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])])
            # print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it,3)) for it in (rel_score[0],rel_score[1],rel_score[2])])
            print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3)) 
            if opt.embedding_score:
                with open("../../ssd0/consistent_dialog/data/GoogleNews-vectors-negative300.bin.p", 'rb') as pfile:
                    embedding = cPickle.load(pfile)
                rel_score = cal_relevance(gen, val_set, embedding)
                print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it,3)) for it in (rel_score[0],rel_score[1],rel_score[2])])


            if not opt.global_feature or opt.bit is None: exit(0)

        if opt.test:
            iter_num = np.int(np.floor(len(test)/opt.batch_size))+1 
            for int_idx in range(opt.int_num):
                res_all = []
                z1,z2,z3 = [],[],[]
                val_tgt_all =[]
                for i in range(iter_num):
                    test_index = range(i * opt.batch_size,(i+1) * opt.batch_size)
                    sents = [test[t%len(test)] for t in test_index]
                    for idx in range(opt.n_context,opt.num_turn):
                        src = [[sents[i][idx-turn] for i in range(opt.batch_size)] for turn in range(opt.n_context,0,-1)]
                        tgt = [sents[i][idx] for i in range(opt.batch_size)]
                        val_tgt_all.extend(tgt)
                        if opt.feed_generated and idx!= opt.n_context:
                            src[-1] = [[x for x in p if x!=0] for p in res_all[-opt.batch_size:]]

                        x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src] # Batch L
                        y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False) 
                        feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0}) # do not use False
                        res_1 = sess.run(res_1_, feed_dict=feed)
                        z_all = np.array(res_1['z'])
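                        # Overwrite the chosen latent bit with an evenly
                        # spaced value in [0, 1] to interpolate that feature.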
                        z_all[:,opt.bit] = np.array([1.0/np.float(opt.int_num-1) * int_idx for _ in range(opt.batch_size)])
                        
                        feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, z_: z_all, is_train_: 0}) # do not use False
                        res_2 = sess.run(res_2_, feed_dict=feed)
                        res_all.extend(res_2['syn_sent'])
                        z1.extend(res_1['z'])                        
                        z2.extend(z_all)
                        z3.extend(res_2['z_hat'])
                        
                        # bp()

                val_tgt_all = reshaping(val_tgt_all, opt)
                res_all = reshaping(res_all, opt)
                z1 = reshaping(z1, opt)
                z2 = reshaping(z2, opt)
                z3 = reshaping(z3, opt)
                
                save_path = opt.log_path  + 'bit' + str(opt.bit) + '.'+ str(1.0/np.float(opt.int_num-1) * int_idx) +'.int.txt'
                if os.path.exists(save_path):
                    os.remove(save_path) 
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path, "a") as resp_f:
                        resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
                print ("save to:" + save_path)

                save_path_z = opt.log_path  + 'bit' + str(opt.bit) + '.'+ str(1.0/np.float(opt.int_num-1) * int_idx) +'.z.txt'
                if os.path.exists(save_path_z):
                    os.remove(save_path_z) 
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path_z, "a") as myfile:
                        #ary = np.array([z1[idx][opt.bit], z2[idx][opt.bit], z3[idx][opt.bit]])
                        #myfile.write(np.array2string(ary, formatter={'float_kind':lambda x: "%.2f" % x}) + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t'))
                        myfile.write(str(z3[idx][opt.bit]) + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t'))

                
                val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                gen = [prepare_for_bleu(s) for s in res_all]
                [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
                etp_score, dist_score = cal_entropy(gen)

                print save_path
                print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)])
                # print 'Val Rouge: ' + ' '.join([str(round(it,3)) for it in (rouge1,rouge2,rouge3,rouge4)])
                print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])])
                print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])])
                # print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it,3)) for it in (rel_score[0],rel_score[1],rel_score[2])])
                print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3)) 
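
merge_two_dicts, used throughout this example to assemble feed dictionaries, is not shown; the standard Python 2 idiom it almost certainly wraps is:

def merge_two_dicts(x, y):
    # Shallow merge: entries from y override entries from x.
    z = x.copy()
    z.update(y)
    return z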
Example no. 11
def main():
    #global n_words
    # Prepare training and testing data

    opt = COptions(args)
    opt_t = COptions(args)

    loadpath = (opt.data_dir + "/" + opt.data_name)
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    if opt.test:
        test_file = opt.data_dir + "/newdata2/test.txt"
        test = read_test(test_file, wordtoix)
        test = [
            x for x in test if all(
                [2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
        ]
    train_filtered = [
        x for x in train
        if all([2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
    ]
    val_filtered = [
        x for x in val
        if all([2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
    ]
    print("Train: %d => %d" % (len(train), len(train_filtered)))
    print("Val: %d => %d" % (len(val), len(val_filtered)))
    train, val = train_filtered, val_filtered
    del train_filtered, val_filtered

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [
                tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
                for _ in range(opt.n_context)
            ]
            tgt_ = tf.placeholder(tf.int32,
                                  shape=[opt_t.batch_size, opt_t.sent_len])
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_, gan_cost_g_, train_op_g = conditional_s2s(
                src_, tgt_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:

                t_vars = tf.trainable_variables()

                if opt.load_from_pretrain:
                    d_vars = [
                        var for var in t_vars if var.name.startswith('d_')
                    ]
                    g_vars = [
                        var for var in t_vars if var.name.startswith('g_')
                    ]
                    l_vars = [
                        var for var in t_vars if var.name.startswith('l_')
                    ]
                    restore_from_save(d_vars,
                                      sess,
                                      opt,
                                      load_path=opt.restore_dir + "/save/" +
                                      opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars,
                                          sess,
                                          opt,
                                          load_path=opt.restore_dir +
                                          "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars,
                                               sess,
                                               opt,
                                               load_path=opt.save_path)

            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            res_all = []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,
                                   (i + 1) * opt.batch_size)
                sents = [test[t % len(test)] for t in test_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[
                        sents[i][idx - turn] for i in range(opt.batch_size)
                    ] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    x_batch = [
                        prepare_data_for_cnn(src_i, opt) for src_i in src
                    ]  # Batch L
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                    feed = merge_two_dicts(
                        {i: d
                         for i, d in zip(src_, x_batch)}, {
                             tgt_: y_batch,
                             is_train_: 0
                         })  # do not use False
                    res = sess.run(res_, feed_dict=feed)
                    res_all.extend(res['syn_sent'])

            # bp()
            res_all = reshaping(res_all, opt)

            for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                with open(opt.log_path + '.resp.txt', "a") as resp_f:
                    resp_f.write(u' '.join([
                        ixtoword[x] for x in res_all[idx] if x != 0 and x != 2
                    ]).encode('utf-8').strip() + (
                        '\n' if idx % (opt.num_turn - opt.n_context) ==
                        opt.num_turn - opt.n_context - 1 else '\t'))
            print("save to:" + opt.log_path + '.resp.txt')
            exit(0)

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[
                        sents[i][idx - turn] for i in range(opt.batch_size)
                    ] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]

                    x_batch = [
                        prepare_data_for_cnn(src_i, opt) for src_i in src
                    ]  # Batch L

                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)

                    feed = merge_two_dicts(
                        {i: d
                         for i, d in zip(src_, x_batch)}, {
                             tgt_: y_batch,
                             is_train_: 1
                         })

                    _, loss_g = sess.run([train_op_g, gan_cost_g_],
                                         feed_dict=feed)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss G %f" % (uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    if opt.global_feature:
                        print "z loss: " + str(res['z_loss'])
                    if "nn" in opt.agg_model:
                        print "z pred_loss: " + str(res['z_loss_pred'])
                    print "Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch
                         for x in s[0] if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""

                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(
                        np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_g_all = [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        sents = [val[t] for t in valid_index]
                        for idx in range(opt.n_context, opt.num_turn):
                            src = [[
                                sents[i][idx - turn]
                                for i in range(opt.batch_size)
                            ] for turn in range(opt.n_context, 0, -1)]
                            tgt = [
                                sents[i][idx] for i in range(opt.batch_size)
                            ]

                            val_tgt_all.extend(tgt)

                            x_batch = [
                                prepare_data_for_cnn(src_i, opt)
                                for src_i in src
                            ]  # Batch L

                            y_batch = prepare_data_for_rnn(tgt,
                                                           opt_t,
                                                           is_add_GO=False)

                            feed = merge_two_dicts(
                                {i: d
                                 for i, d in zip(src_, x_batch)}, {
                                     tgt_: y_batch,
                                     is_train_: 0
                                 })  # do not use False

                            loss_val_g = sess.run([gan_cost_g_],
                                                  feed_dict=feed)
                            loss_val_g_all.append(loss_val_g)

                            res = sess.run(res_, feed_dict=feed)
                            res_all.extend(res['syn_sent'])
                        if opt.global_feature:
                            z_loss_all.append(res['z_loss'])

                    print("Validation:  loss G %f " %
                          (np.mean(loss_val_g_all)))
                    if opt.global_feature:
                        print "z loss: " + str(np.mean(z_loss_all))
                    print "Val Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch
                         for x in s[0] if x != 0]).encode('utf-8').strip()
                    print "Val Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""
                    if opt.global_feature:
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" % (np.mean(z_loss_all)) +
                                         "\n")
                            myfile.write("Val Source:" + u' '.join([
                                ixtoword[x] for s in x_batch
                                for x in s[0] if x != 0
                            ]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Target:" + u' '.join(
                                [ixtoword[x] for x in y_batch[0]
                                 if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Generated:" + u' '.join([
                                ixtoword[x]
                                for x in res['syn_sent'][0] if x != 0
                            ]).encode('utf-8').strip() + "\n")
                            myfile.write("Z_input, Z_recon, Z_tgt")
                            myfile.write(
                                np.array2string(res['z'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n")
                            myfile.write(
                                np.array2string(res['z_hat'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n\n")
                            myfile.write(
                                np.array2string(res['z_tgt'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n\n")

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: val_set},
                                          is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)

                    print 'Val BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu1s, bleu2s, bleu3s, bleu4s)
                    ])
                    print 'Val Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ])
                    print 'Val Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1], dist_score[2],
                                   dist_score[3])
                    ])
                    print 'Val Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0]) for x in res_all
                            ]), 3))
                    print ""
                    summary = sess.run(merged, feed_dict=feed)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="etp-4",
                                         simple_value=etp_score[3])
                    ])

                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
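
The hand-built tf.Summary protobuf above is the usual TF1 way to log Python-side metrics (here BLEU and entropy) that have no summary op in the graph. A small reusable sketch of the same pattern (the helper name is ours):

import tensorflow as tf

def log_scalar(writer, tag, value, step):
    # Write a plain Python float to TensorBoard without adding a graph op.
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    writer.add_summary(summary, step)
    writer.flush()

# e.g. log_scalar(test_writer, 'bleu-2', bleu2s, uidx)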