Example #1
 def translate(trainer):
     valid_index = np.random.choice(len(test), args.batchsize)
     val_sents = [test[t] for t in valid_index]
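     # corrupt the inputs; the model is trained to reconstruct the clean sentences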
     val_sents_permutated = denoise.add_noise(val_sents, opt)
     x_val_batch = utils.prepare_data_for_cnn(val_sents_permutated,
                                              opt.maxlen, opt.filter_shape)
     x_val_batch_org = utils.prepare_data_for_rnn(val_sents,
                                                  opt.maxlen,
                                                  opt.sent_len,
                                                  opt.n_words,
                                                  is_add_GO=True)
     mdl = generator
     xp = mdl.xp
     x_val_batch = xp.array(x_val_batch)
     x_val_batch_org = xp.array(x_val_batch_org)
     syn_sents, logits = mdl(x_val_batch, x_val_batch_org)
     prob = [F.softmax(l * opt.L) for l in logits]
     prob = F.stack(prob, 1)
     source_sentence = ' '.join(
         [source_words[int(i)] for i in x_val_batch_org[0] if i != PAD])
     result_sentence = ' '.join([
         source_words.get(int(i), '*NOKEY') for i in syn_sents.data[0]
         if i != PAD
     ])
     prob_sentence = ' '.join(
         [source_words[int(xp.argmax(p))] for p in prob.data[0]])
     print('# source : ' + source_sentence)
     print('# sent2  : ' + result_sentence)
     print('# prob   : ' + prob_sentence)
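
Since translate(trainer) takes the trainer as its only argument, it fits Chainer's trainer-extension interface. A minimal wiring sketch (the trigger interval is an assumption, not taken from the source):

# Hedged sketch: report sample translations once per epoch via a trainer extension.
trainer.extend(translate, trigger=(1, 'epoch'))
trainer.run()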
Example #2
def convert(batch, device):
    if device < 0:
        xp = np
    else:
        xp = cuda.cupy
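    # keep a clean copy for the decoder target; only the encoder input is noised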
    x = [x_ for x_ in batch]
    x = denoise.add_noise(x, opt)
    x_org = [x_ for x_ in batch]
    x = utils.prepare_data_for_cnn(x, opt.maxlen, opt.filter_shape)
    x_org = utils.prepare_data_for_rnn(x_org, opt.maxlen, opt.sent_len, opt.n_words, is_add_GO=True)
    x = xp.array(x, dtype=np.int32)
    x_org = xp.array(x_org, dtype=np.int32)
    return {'x': x, 'x_org': x_org}
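
The (batch, device) signature above matches what Chainer updaters expect from their converter argument; because convert returns a dict, the updater calls the loss function with keyword arguments x and x_org. A minimal wiring sketch (train_iter, optimizer, and the stop trigger are assumptions):

# Hedged sketch: plug convert() in as the batch converter.
updater = chainer.training.StandardUpdater(
    train_iter, optimizer, converter=convert, device=0)
trainer = chainer.training.Trainer(updater, (20, 'epoch'))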
Example #3
def main():
    import utils
    maxlen = 51
    filter_shape = 5
    sent_len = maxlen + 2 * (filter_shape - 1)
    n_words = 5728
    #m = auto_encoder(n_words, maxlen=maxlen)
    m = textGan_generator(n_words, maxlen)
    d = textGan_discriminator(m.embedding, n_words, maxlen=maxlen)
    # m = textGan(n_words, maxlen=maxlen)
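    # toy batch: 2 "sentences" of 20 ascending token ids each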
    data = np.arange(20 * 2, dtype=np.int32).reshape(2, 20)
    x = utils.prepare_data_for_cnn(data, maxlen, filter_shape)
    x_orig = utils.prepare_data_for_rnn(data, maxlen, sent_len, n_words)
    syn_sents, prob = m(x, x_orig)
Example #4
    def update_core(self):
        gen_optimizer = self.get_optimizer('opt_gen')
        dis_optimizer = self.get_optimizer('opt_dis')

        xp = self.gen.xp
        opt = self.opt

        batch = self.get_iterator('main').next()
        batchsize = len(batch)
        x = denoise.add_noise(batch, self.opt)
        x = utils.prepare_data_for_cnn(x, opt.maxlen, opt.filter_shape)
        x_org = utils.prepare_data_for_rnn(batch,
                                           opt.maxlen,
                                           opt.sent_len,
                                           opt.n_words,
                                           is_add_GO=True)
        x = xp.array(x, dtype=np.int32)
        x_org = xp.array(x_org, dtype=np.int32)
        # generator
        syn_sents, prob = self.gen(x, x_org)  # prob: fake data

        # discriminator
        logits_real, H_real = self.dis(x)
        logits_fake, H_fake = self.dis(prob, is_prob=True)

        # integer class labels: 1 for real, 0 for fake (1-dim arrays)
        labels_one = xp.ones((batchsize,), dtype=xp.int32)
        labels_zero = xp.zeros((batchsize,), dtype=xp.int32)
        labels_fake = labels_zero
        labels_real = labels_one
        D_loss = F.softmax_cross_entropy(logits_real, labels_real) + \
            F.softmax_cross_entropy(logits_fake, labels_fake)

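        # generator loss: match real and synthetic discriminator features via MMD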
        G_loss = compute_MMD_loss(F.squeeze(H_fake), F.squeeze(H_real))

        self.gen.cleargrads()
        G_loss.backward()
        gen_optimizer.update()

        self.dis.cleargrads()
        D_loss.backward()
        dis_optimizer.update()

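        # cut the graph so buffers from this iteration can be freed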
        H_fake.unchain_backward()
        H_real.unchain_backward()
        prob.unchain_backward()

        chainer.reporter.report({'loss_gen': G_loss})
        chainer.reporter.report({'loss_dis': D_loss})
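
The generator objective above is computed by compute_MMD_loss on discriminator features of synthetic and real batches. A minimal NumPy sketch of the biased squared MMD with a Gaussian kernel (kernel choice and bandwidth are assumptions; the repo's implementation may differ):

import numpy as np

def mmd2_gaussian(h_fake, h_real, sigma=1.0):
    # Biased estimate of squared MMD between two (batch, dim) feature arrays.
    def gram(a, b):
        # pairwise squared distances -> RBF kernel matrix
        d = (np.sum(a ** 2, 1)[:, None] + np.sum(b ** 2, 1)[None, :]
             - 2.0 * np.dot(a, b.T))
        return np.exp(-d / (2.0 * sigma ** 2))
    return (gram(h_fake, h_fake).mean() + gram(h_real, h_real).mean()
            - 2.0 * gram(h_fake, h_real).mean())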
Example #5
    def run_epoch(sess,
                  epoch,
                  mode,
                  print_freq=-1,
                  display_sent=-1,
                  train_writer=None):
        fetches_ = {'loss': loss_}
        if mode == 'train':
            x, is_train = train, 1
            fetches_['train_op'] = train_op_
            fetches_['summary'] = merged
        elif mode == 'val':
            assert (print_freq == -1)
            x, is_train = val, None
        elif mode == 'test':
            assert (print_freq == -1)
            x, is_train = test, None

        acc_loss, acc_n = 0.0, 0.0
        local_t = 0
        global_t = epoch * epoch_t  # only used in train mode
        start_time = time.time()
        kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)

        for _, index in kf:
            local_t += 1
            global_t += 1

            sents_b = [x[i] for i in index]
            sents_b_n = add_noise(sents_b, opt)
            x_b_org = prepare_data_for_rnn(sents_b, opt)  # Batch L
            x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch L
            feed_t = {x_: x_b, x_org_: x_b_org, is_train_: is_train}
            fetches = sess.run(fetches_, feed_dict=feed_t)

            batch_size = len(index)
            acc_n += batch_size
            acc_loss += fetches['loss'] * batch_size
            if print_freq > 0 and local_t % print_freq == 0:
                print("%s Iter %d: loss %.4f, time %.1fs" %
                      (mode, local_t, acc_loss / acc_n,
                       time.time() - start_time))
                sys.stdout.flush()
            if mode == 'train' and train_writer is not None:
                train_writer.add_summary(fetches['summary'], global_t)

        if display_sent > 0:
            index_d = np.random.choice(len(x), opt.batch_size, replace=False)
            sents_d = [x[i] for i in index_d]
            sents_d_n = add_noise(sents_d, opt)
            x_d_org = prepare_data_for_rnn(sents_d, opt)  # Batch L
            x_d = prepare_data_for_cnn(sents_d_n, opt)  # Batch L
            res = sess.run(res_,
                           feed_dict={
                               x_: x_d,
                               x_org_: x_d_org,
                               is_train_: is_train
                           })
            for i in range(display_sent):
                print(
                    "%s Org: " % mode + " ".join([
                        ixtoword[ix]
                        for ix in sents_d[i] if ix != 0 and ix != 2
                    ]))
                if mode == 'train':
                    print(
                        "%s Rec(feedy): " % mode + " ".join([
                            ixtoword[ix] for ix in res['rec_sents_feed_y'][i]
                            if ix != 0 and ix != 2
                        ]))
                print(
                    "%s Rec: " % mode + " ".join([
                        ixtoword[ix]
                        for ix in res['rec_sents'][i] if ix != 0 and ix != 2
                    ]))

        print("%s Epoch %d: loss %.4f, time %.1fs" %
              (mode, epoch, acc_loss / acc_n, time.time() - start_time))
        return acc_loss / acc_n
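
run_epoch returns the epoch's mean loss, so a driver loop can key checkpointing off the validation value. A sketch under assumed names (opt, saver, and train_writer come from the enclosing scope; the early-stopping rule is an assumption):

best_val = float('inf')
for epoch in range(opt.max_epochs):
    run_epoch(sess, epoch, 'train', print_freq=100, train_writer=train_writer)
    val_loss = run_epoch(sess, epoch, 'val', display_sent=2)
    if val_loss < best_val:  # keep the checkpoint with the best validation loss
        best_val = val_loss
        saver.save(sess, opt.save_path)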
Example #6
def main():
    
    
    opt = Options(args)
    opt_t = Options(args)
    opt_t.n_hid = opt.n_z
    loadpath = (opt.data_dir + "/" + opt.data_name) 
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    
    
    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    if opt.model == 'cnn_rnn':
        opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
        opt_t.update_params(args)
        print dict(opt_t)
    print('Total words: %d' % opt.n_words)
    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            z_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.n_z])
            res_, gan_cost_g_, train_op_g = conditional_s2s(src_, tgt_, z_, opt, opt_t)
            merged = tf.summary.merge_all()
    
    
    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True, graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()
    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                
                t_vars = tf.trainable_variables()                  
                if opt.load_from_pretrain:
                    d_vars = [var for var in t_vars if var.name.startswith('d_')]
                    g_vars = [var for var in t_vars if var.name.startswith('g_')]
                    restore_from_save(g_vars, sess, opt, opt.restore_dir + "/save/generator")
                    restore_from_save(d_vars, sess, opt, opt.restore_dir + "/save/discriminator")
                else:
                    loader = restore_from_save(t_vars, sess, opt)  
            except Exception as e:
                print 'Error: '+str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0
        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
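                # consecutive (src, tgt) sentence-index pairs: (0,1), (1,2), ..., (6,7)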
                indice = [(x,x+1) for x in range(7)]
                sents = [train[t] for t in train_index]
                for idx in indice:
                    src = [sents[i][idx[0]] for i in range(opt.batch_size)]
                    tgt = [sents[i][idx[1]] for i in range(opt.batch_size)]
                    src_permutated = src 
                    x_batch = prepare_data_for_cnn(src_permutated, opt) 
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False) 
                    # sample z to match the z_ placeholder shape (batch_size, n_z)
                    if opt.z_prior == 'g':
                        z_batch = np.random.normal(0, 1, (opt.batch_size, opt.n_z)).astype('float32')
                    else:
                        z_batch = np.random.uniform(-1, 1, (opt.batch_size, opt.n_z)).astype('float32')
                    
                    feed = {src_: x_batch, tgt_: y_batch, z_:z_batch}
                    _, loss_g = sess.run([train_op_g, gan_cost_g_],feed_dict=feed)
                if uidx%opt.print_freq == 0:
                    print("Iteration %d: loss G %f" %(uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    print "Source:" + u' '.join([ixtoword[x] for x in x_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([ixtoword[x] for x in y_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)
                    
                    
                if uidx%opt.valid_freq == 1:
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(np.floor(VALID_SIZE/opt.batch_size))
                    res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(test), opt.batch_size)
                        indice = [(x,x+1) for x in range(7)]
                        val_sents = [test[t] for t in valid_index]
                        for idx in indice:
                            val_src = [val_sents[i][idx[0]] for i in range(opt.batch_size)]
                            val_tgt = [val_sents[i][idx[1]] for i in range(opt.batch_size)]
                            val_tgt_all.extend(val_tgt)
                            val_src_permutated = val_src 
                            
                            x_val_batch = prepare_data_for_cnn(val_src, opt)
                            # the target side feeds val_tgt, mirroring the training feed above
                            y_val_batch = prepare_data_for_rnn(val_tgt, opt_t, is_add_GO = False) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(val_tgt, opt_t)
                            if opt.z_prior == 'g':
                                z_val_batch = np.random.normal(0,1,(opt.batch_size, opt.n_z)).astype('float32')
                            else:
                                z_val_batch = np.random.uniform(-1,1,(opt.batch_size, opt.n_z)).astype('float32')
                            feed_val = {src_: x_val_batch, tgt_: y_val_batch, z_:z_val_batch}
                            loss_val_g = sess.run([gan_cost_g_], feed_dict=feed_val)
                            loss_val_g_all.append(loss_val_g)
                            res = sess.run(res_, feed_dict=feed_val)
                            res_all.extend(res['syn_sent'])
                        if opt.global_feature:
                            z_loss_all.append(res['z_loss'])
                        
                    
                    
                    print("Validation:  loss G %f " %( np.mean(loss_val_g_all)))
                    
                    print "Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()
                    print "Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    if opt.global_feature:
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" %(np.mean(z_loss_all))+ "\n")
                            myfile.write("Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()+ "\n")
                            myfile.write("Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()+ "\n")
                            myfile.write("Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()+ "\n")
                            myfile.write(np.array2string(res['z'][0], formatter={'float_kind':lambda x: "%.2f" % x})+ "\n")
                            myfile.write(np.array2string(res['z_hat'][0], formatter={'float_kind':lambda x: "%.2f" % x})+ "\n\n")
                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    
                    
                    
                    [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
                    
                    etp_score, dist_score = cal_entropy(gen)
                    
                    
                    
                    print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)])
                    
                    print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])])
                    print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])])
                    
                    print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3)) 
                    print ""
                    summary = sess.run(merged, feed_dict=feed_val)
                    summary2 = tf.Summary(value=[tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),tf.Summary.Value(tag="etp-4", simple_value=etp_score[3])])
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)
                if uidx%opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
Example #7
def main():
    # global n_words
    # Prepare training and testing data
    # loadpath = "./data/three_corpus_small.p"
    loadpath = "./data/hotel_reviews.p"
    x = cPickle.load(open(loadpath, "rb"))
    train, val = x[0], x[1]
    wordtoix, ixtoword = x[2], x[3]

    opt = Options()
    opt.n_words = len(ixtoword) + 1
    ixtoword[opt.n_words - 1] = 'GO_'
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' + str(
                      (opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, loss_, train_op = auto_encoder(x_, x_org_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    # writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                # pdb.set_trace()

                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

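                # opt.substitution selects the corruption: (s)ubstitute, (p)ermutate, (a)dd, or (d)elete words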
                if opt.substitution == 's':
                    sents_permutated = substitute_sent(sents, opt)
                elif opt.substitution == 'p':
                    sents_permutated = permutate_sent(sents, opt)
                elif opt.substitution == 'a':
                    sents_permutated = add_sent(sents, opt)
                elif opt.substitution == 'd':
                    sents_permutated = delete_sent(sents, opt)
                else:
                    sents_permutated = sents

                # sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated,
                                                   opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated,
                                                   opt,
                                                   is_add_GO=False)  # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print

                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()

                _, loss = sess.run([train_op, loss_],
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org
                                   })

                if uidx % opt.valid_freq == 0:
                    opt.is_train = False
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    if opt.substitution == 's':
                        val_sents_permutated = substitute_sent(val_sents, opt)
                    elif opt.substitution == 'p':
                        val_sents_permutated = permutate_sent(val_sents, opt)
                    elif opt.substitution == 'a':
                        val_sents_permutated = add_sent(val_sents, opt)
                    elif opt.substitution == 'd':
                        val_sents_permutated = delete_sent(val_sents, opt)
                    else:
                        val_sents_permutated = val_sents

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={
                                            x_: x_val_batch,
                                            x_org_: x_val_batch_org
                                        })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org
                                   })
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    print "Val Orig :" + " ".join(
                        [ixtoword[x] for x in val_sents[0] if x != 0])
                    print "Val Perm :" + " ".join([
                        ixtoword[x] for x in val_sents_permutated[0] if x != 0
                    ])
                    print "Val Recon:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])
                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org
                                       })
                    test_writer.add_summary(summary, uidx)
                    opt.is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org
                                   })
                    print "Original     :" + " ".join(
                        [ixtoword[x] for x in sents[0] if x != 0])
                    print "Permutated   :" + " ".join(
                        [ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        print "Reconstructed:" + " ".join([
                            ixtoword[x]
                            for x in res['rec_sents_feed_y'][0] if x != 0
                        ])
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org
                                       })
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            saver.save(sess, opt.save_path, global_step=epoch)
Example #8
def run_model(opt, train, val, test, wordtoix, ixtoword):

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        summary_ext = tf.Summary()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated,
                                                   opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated,
                                                   opt,
                                                   is_add_GO=False)  # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print

                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()

                # optionally trace this step so tfprof can report timing and memory
                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={
                            x_: x_batch,
                            x_org_: x_batch_org,
                            is_train_: 1
                        },
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org,
                                           is_train_: 1
                                       })

                #pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={
                                            x_: x_val_batch,
                                            x_org_: x_val_batch_org,
                                            is_train_: is_train
                                        })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org,
                                       is_train_: is_train
                                   })
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    if opt.char:
                        print "Val Orig :" + "".join(
                            [ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + "".join([
                            ixtoword[x]
                            for x in val_sents_permutated[0] if x != 0
                        ])
                        print "Val Recon:" + "".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])
                        # print "Val Recon one hot:" + "".join([ixtoword[x] for x in res['rec_sents_one_hot'][0] if x != 0])
                    else:
                        print "Val Orig :" + " ".join(
                            [ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + " ".join([
                            ixtoword[x]
                            for x in val_sents_permutated[0] if x != 0
                        ])
                        print "Val Recon:" + " ".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])

                    val_set_char = [
                        prepare_for_cer(s, ixtoword) for s in val_sents
                    ]
                    cer = cal_cer([
                        prepare_for_cer(s, ixtoword) for s in res['rec_sents']
                    ], val_set_char)
                    print 'Val CER: ' + str(round(cer, 3))
                    # summary_ext.Value(tag='CER', simple_value=cer)
                    summary_ext = tf.Summary(
                        value=[tf.Summary.Value(tag='CER', simple_value=cer)])
                    # tf.summary.scalar('CER', cer)

                    #if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #print "Gen Probs:" + " ".join([str(np.round(res['gen_p'][i], 1)) for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org,
                                           is_train_: is_train
                                       })
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary_ext, uidx)
                    is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org,
                                       is_train_: 1
                                   })

                    # if 1 in res['rec_sents'][0] or 1 in sents[0]:
                    #     pdb.set_trace()
                    if opt.char:
                        print "Original     :" + "".join(
                            [ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated   :" + "".join([
                            ixtoword[x] for x in sents_permutated[0] if x != 0
                        ])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + "".join([
                                ixtoword[x]
                                for x in res['rec_sents_feed_y'][0] if x != 0
                            ])
                        print "Reconstructed:" + "".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])

                    else:
                        print "Original     :" + " ".join(
                            [ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated   :" + " ".join([
                            ixtoword[x] for x in sents_permutated[0] if x != 0
                        ])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + " ".join([
                                ixtoword[x]
                                for x in res['rec_sents_feed_y'][0] if x != 0
                            ])
                        print "Reconstructed:" + " ".join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ])

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org,
                                           is_train_: 1
                                       })
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                            tf.get_default_graph(),
                            run_meta=run_metadata,
                            tfprof_options=tf.contrib.tfprof.model_analyzer.
                            PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path)
Example #9
def main():
    #global n_words
    # Prepare training and testing data

    opt = COptions(args)
    opt_t = COptions(args)

    loadpath = (opt.data_dir + "/" + opt.data_name)
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    if opt.test:
        test_file = opt.data_dir + "/newdata2/test.txt"
        test = read_test(test_file, wordtoix)
        test = [
            x for x in test if all(
                [2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
        ]
    train_filtered = [
        x for x in train
        if all([2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
    ]
    val_filtered = [
        x for x in val
        if all([2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
    ]
    print("Train: %d => %d" % (len(train), len(train_filtered)))
    print("Val: %d => %d" % (len(val), len(val_filtered)))
    train, val = train_filtered, val_filtered
    del train_filtered, val_filtered

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [
                tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
                for _ in range(opt.n_context)
            ]
            tgt_ = tf.placeholder(tf.int32,
                                  shape=[opt_t.batch_size, opt_t.sent_len])
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_, gan_cost_g_, train_op_g = conditional_s2s(
                src_, tgt_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:

                t_vars = tf.trainable_variables()

                if opt.load_from_pretrain:
                    d_vars = [
                        var for var in t_vars if var.name.startswith('d_')
                    ]
                    g_vars = [
                        var for var in t_vars if var.name.startswith('g_')
                    ]
                    l_vars = [
                        var for var in t_vars if var.name.startswith('l_')
                    ]
                    restore_from_save(d_vars,
                                      sess,
                                      opt,
                                      load_path=opt.restore_dir + "/save/" +
                                      opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars,
                                          sess,
                                          opt,
                                          load_path=opt.restore_dir +
                                          "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars,
                                               sess,
                                               opt,
                                               load_path=opt.save_path)

            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            res_all = []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,
                                   (i + 1) * opt.batch_size)
                # draw from test (not val); wrap around so the final batch stays full-sized
                sents = [test[t % len(test)] for t in test_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[
                        sents[i][idx - turn] for i in range(opt.batch_size)
                    ] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    x_batch = [
                        prepare_data_for_cnn(src_i, opt) for src_i in src
                    ]  # Batch L
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                    feed = merge_two_dicts(
                        {i: d
                         for i, d in zip(src_, x_batch)}, {
                             tgt_: y_batch,
                             is_train_: 0
                         })  # do not use False
                    res = sess.run(res_, feed_dict=feed)
                    res_all.extend(res['syn_sent'])

            # bp()
            res_all = reshaping(res_all, opt)

            for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                with open(opt.log_path + '.resp.txt', "a") as resp_f:
                    resp_f.write(u' '.join([
                        ixtoword[x] for x in res_all[idx] if x != 0 and x != 2
                    ]).encode('utf-8').strip() + (
                        '\n' if idx %
                        (opt.num_turn - opt.n_context) == 0 else '\t'))
            print("save to:" + opt.log_path + '.resp.txt')
            exit(0)

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
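                # the previous n_context turns form the source; the current turn idx is the target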
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[
                        sents[i][idx - turn] for i in range(opt.batch_size)
                    ] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]

                    x_batch = [
                        prepare_data_for_cnn(src_i, opt) for src_i in src
                    ]  # Batch L

                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)

                    feed = merge_two_dicts(
                        {i: d
                         for i, d in zip(src_, x_batch)}, {
                             tgt_: y_batch,
                             is_train_: 1
                         })

                    _, loss_g = sess.run([train_op_g, gan_cost_g_],
                                         feed_dict=feed)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss G %f" % (uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    if opt.global_feature:
                        print "z loss: " + str(res['z_loss'])
                    if "nn" in opt.agg_model:
                        print "z pred_loss: " + str(res['z_loss_pred'])
                    print "Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch
                         for x in s[0] if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""

                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(
                        np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_g_all = [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        sents = [val[t] for t in valid_index]
                        for idx in range(opt.n_context, opt.num_turn):
                            src = [[
                                sents[i][idx - turn]
                                for i in range(opt.batch_size)
                            ] for turn in range(opt.n_context, 0, -1)]
                            tgt = [
                                sents[i][idx] for i in range(opt.batch_size)
                            ]

                            val_tgt_all.extend(tgt)

                            x_batch = [
                                prepare_data_for_cnn(src_i, opt)
                                for src_i in src
                            ]  # Batch L

                            y_batch = prepare_data_for_rnn(tgt,
                                                           opt_t,
                                                           is_add_GO=False)

                            feed = merge_two_dicts(
                                {i: d
                                 for i, d in zip(src_, x_batch)}, {
                                     tgt_: y_batch,
                                     is_train_: 0
                                 })  # do not use False

                            loss_val_g = sess.run([gan_cost_g_],
                                                  feed_dict=feed)
                            loss_val_g_all.append(loss_val_g)

                            res = sess.run(res_, feed_dict=feed)
                            res_all.extend(res['syn_sent'])
                        if opt.global_feature:
                            z_loss_all.append(res['z_loss'])

                    print("Validation:  loss G %f " %
                          (np.mean(loss_val_g_all)))
                    if opt.global_feature:
                        print "z loss: " + str(np.mean(z_loss_all))
                    print "Val Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch
                         for x in s[0] if x != 0]).encode('utf-8').strip()
                    print "Val Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""
                    if opt.global_feature:
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" % (np.mean(z_loss_all)) +
                                         "\n")
                            myfile.write("Val Source:" + u' '.join([
                                ixtoword[x] for s in x_batch
                                for x in s[0] if x != 0
                            ]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Target:" + u' '.join(
                                [ixtoword[x] for x in y_batch[0]
                                 if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Generated:" + u' '.join([
                                ixtoword[x]
                                for x in res['syn_sent'][0] if x != 0
                            ]).encode('utf-8').strip() + "\n")
                            myfile.write("Z_input, Z_recon, Z_tgt")
                            myfile.write(
                                np.array2string(res['z'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n")
                            myfile.write(
                                np.array2string(res['z_hat'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n\n")
                            myfile.write(
                                np.array2string(res['z_tgt'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n\n")

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: val_set},
                                          is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)

                    print 'Val BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu1s, bleu2s, bleu3s, bleu4s)
                    ])
                    print 'Val Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ])
                    print 'Val Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1], dist_score[2],
                                   dist_score[3])
                    ])
                    print 'Val Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0]) for x in res_all
                            ]), 3))
                    print ""
                    summary = sess.run(merged, feed_dict=feed)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="etp-4",
                                         simple_value=etp_score[3])
                    ])

                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
Example #10
    def run_epoch(sess,
                  epoch,
                  mode,
                  print_freq=-1,
                  display_sent=-1,
                  train_writer=None):
        fetches_ = {'loss': loss_, 'rec_loss': rec_loss_, 'kl_loss': kl_loss_}

        if mode == 'train':
            x, is_train = train, 1
            fetches_['train_op'] = train_op_
            fetches_['summary'] = merged
        elif mode == 'val':
            assert (print_freq == -1)
            x, is_train = val, None
        elif mode == 'test':
            assert (print_freq == -1)
            x, is_train = test, None

        acc_loss, acc_rec, acc_kl, acc_n = 0.0, 0.0, 0.0, 0.0
        local_t = 0
        global_t = epoch * epoch_t  # only used in train mode
        start_time = time.time()
        kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)

        for _, index in kf:
            local_t += 1
            global_t_cyc = global_t % cycle_t
            lr_t = 0.5 * opt.lr * (
                1 + np.cos(float(global_t_cyc) / cycle_t * np.pi))
            global_t += 1
            if mode == 'train':
                if opt.vae_anneal:
                    beta_t = opt.max_beta * np.minimum(
                        (global_t_cyc + 1.) / full_kl_step, 1.)
                else:
                    beta_t = opt.max_beta
            else:
                beta_t = opt.max_beta

            sents_b = [x[i] for i in index]
            sents_b_n = add_noise(sents_b, opt)
            x_b_org = prepare_data_for_rnn(sents_b, opt)  # Batch L
            x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch L
            feed_t = {
                beta_: beta_t,
                x_: x_b,
                x_org_: x_b_org,
                is_train_: is_train,
                lr_: lr_t
            }
            fetches = sess.run(fetches_, feed_dict=feed_t)

            batch_size = len(index)
            acc_n += batch_size
            acc_loss += fetches['loss'] * batch_size
            acc_rec += fetches['rec_loss'] * batch_size
            acc_kl += fetches['kl_loss'] * batch_size
            if print_freq > 0 and local_t % print_freq == 0:
                print(
                    "%s Iter %d: loss %.4f, rec %.4f, kl %.4f, beta %.4f, lr %.4fe-4, time %.1fs"
                    %
                    (mode, local_t, acc_loss / acc_n, acc_rec / acc_n, acc_kl /
                     acc_n, beta_t, lr_t * 1e4, time.time() - start_time))
                sys.stdout.flush()
            if mode == 'train' and train_writer is not None:
                train_writer.add_summary(fetches['summary'], global_t)

        if display_sent > 0:
            index_d = np.random.choice(len(x), opt.batch_size, replace=False)
            sents_d = [x[i] for i in index_d]
            sents_d_n = add_noise(sents_d, opt)
            x_d_org = prepare_data_for_rnn(sents_d, opt)  # Batch L
            x_d = prepare_data_for_cnn(sents_d_n, opt)  # Batch L
            res = sess.run(res_,
                           feed_dict={
                               beta_: beta_t,
                               x_: x_d,
                               x_org_: x_d_org,
                               is_train_: is_train
                           })
            for i in range(display_sent):
                print(
                    "%s Org: " % mode + " ".join([
                        ixtoword[ix]
                        for ix in sents_d[i] if ix != 0 and ix != 2
                    ]))
                if mode == 'train':
                    print(
                        "%s Rec(feedy): " % mode + " ".join([
                            ixtoword[ix] for ix in res['rec_sents_feed_y'][i]
                            if ix != 0 and ix != 2
                        ]))
                print(
                    "%s Rec: " % mode + " ".join([
                        ixtoword[ix]
                        for ix in res['rec_sents'][i] if ix != 0 and ix != 2
                    ]))

        print(
            "%s Epoch %d: loss %.4f, rec %.4f, kl %.4f, beta %.4f, time %.1fs"
            % (mode, epoch, acc_loss / acc_n, acc_rec / acc_n, acc_kl / acc_n,
               beta_t, time.time() - start_time))
        return acc_loss / acc_n, acc_rec / acc_n, acc_kl / acc_n
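
For reference, a minimal, self-contained sketch of the two schedules used in run_epoch above: the cyclic cosine learning-rate decay and the linear KL-weight (beta) annealing. The names base_lr, max_beta, cycle_t and full_kl_step stand in for opt.lr, opt.max_beta and the step counts the snippet assumes; they are illustrative, not part of the original code.

import numpy as np

def schedules(global_t, base_lr, max_beta, cycle_t, full_kl_step):
    """Return (lr_t, beta_t) for the current global step."""
    t = global_t % cycle_t  # position inside the current cycle
    # cosine decay: lr goes from base_lr at t=0 down to 0 at t=cycle_t
    lr_t = 0.5 * base_lr * (1 + np.cos(float(t) / cycle_t * np.pi))
    # linear KL annealing: beta ramps up to max_beta, then stays flat
    beta_t = max_beta * min((t + 1.0) / full_kl_step, 1.0)
    return lr_t, beta_t

# example: sample the schedule at a few steps of a 1000-step cycle
for step in (0, 250, 500, 999):
    print(schedules(step, base_lr=1e-3, max_beta=1.0, cycle_t=1000, full_kl_step=400))
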
Beispiel #11
0
def run_model(opt, train_unlab_x, train_lab_x, train_lab, val_unlab_x,
              val_lab_x, val_lab, test, test_y, wordtoix, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' + str(
                      (opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        alpha_ = tf.placeholder(tf.float32, shape=())
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_lab_ = tf.placeholder(tf.int32,
                                shape=[opt.dis_batch_size, opt.sent_len])
        y_ = tf.placeholder(tf.float32, shape=[opt.dis_batch_size, 1])
        dp_ratio_ = tf.placeholder(tf.float32, name='dp_ratio_')
        res_, dis_loss_, rec_loss_, loss_, train_op, prob_, acc_ = semi_classifier(
            alpha_, x_, x_org_, x_lab_, y_, dp_ratio_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    max_val_accuracy = 0.0
    max_test_accuracy = 0.0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):

            print("Starting epoch %d" % epoch)

            kf = get_minibatches_idx(len(train_unlab_x),
                                     opt.batch_size,
                                     shuffle=True)
            for _, train_index in kf:
                uidx += 1

                if opt.rec_alpha > 0 and uidx > opt.pretrain_step and uidx % opt.rec_decay_freq == 0:
                    opt.rec_alpha -= 0.01
                    print "alpha: " + str(opt.rec_alpha)

                sents = [train_unlab_x[t] for t in train_index]

                lab_index = np.random.choice(len(train_lab),
                                             opt.dis_batch_size,
                                             replace=False)
                lab_sents = [train_lab_x[t] for t in lab_index]
                batch_lab = [train_lab[t] for t in lab_index]
                batch_lab = np.array(batch_lab)
                batch_lab = batch_lab.reshape((len(batch_lab), 1))
                x_batch_lab = prepare_data_for_cnn(lab_sents, opt)

                sents_permutated = add_noise(sents, opt)

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated,
                                                   opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated,
                                                   opt,
                                                   is_add_GO=False)  # Batch L

                # fetch the summary in the same run to avoid a second forward pass
                _, dis_loss, rec_loss, loss, acc, summary = sess.run(
                    [train_op, dis_loss_, rec_loss_, loss_, acc_, merged],
                    feed_dict={
                        alpha_: opt.rec_alpha,
                        x_: x_batch,
                        x_org_: x_batch_org,
                        x_lab_: x_batch_lab,
                        y_: batch_lab,
                        dp_ratio_: opt.dropout_ratio
                    })
                train_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    print(
                        "Iteration %d: dis_loss %f, rec_loss %f, loss %f, acc %f "
                        % (uidx, dis_loss, rec_loss, loss, acc))

                if uidx % opt.valid_freq == 0:
                    #print("Iteration %d: dis_loss %f, rec_loss %f, loss %f " % (uidx, dis_loss, rec_loss, loss))
                    valid_index = np.random.choice(len(val_unlab_x),
                                                   opt.batch_size)
                    val_sents = [val_unlab_x[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    rec_loss_val = sess.run(rec_loss_,
                                            feed_dict={
                                                x_: x_val_batch,
                                                x_org_: x_val_batch_org,
                                                dp_ratio_: 1.0
                                            })
                    print("Validation rec loss %f " % rec_loss_val)

                    kf_val = get_minibatches_idx(len(val_lab_x),
                                                 opt.dis_batch_size,
                                                 shuffle=False)

                    prob_val = []
                    for _, val_ind in kf_val:
                        val_sents = [val_lab_x[t] for t in val_ind]
                        x_val_dis = prepare_data_for_cnn(val_sents, opt)
                        val_y = np.array([val_lab[t]
                                          for t in val_ind]).reshape(
                                              (opt.dis_batch_size, 1))
                        val_prob = sess.run(prob_,
                                            feed_dict={
                                                x_lab_: x_val_dis,
                                                dp_ratio_: 1.0
                                            })
                        for x in val_prob:
                            prob_val.append(x)


                    probs = []
                    val_truth = []
                    for i in range(len(prob_val)):
                        val_truth.append(val_lab[i])
                        probs.append(prob_val[i])

                    count = 0.0
                    for i in range(len(probs)):
                        p = probs[i]
                        if p > 0.5:
                            if val_truth[i] == 1:
                                count += 1.0
                        else:
                            if val_truth[i] == 0:
                                count += 1.0

                    val_accuracy = count * 1.0 / len(probs)

                    print("Validation accuracy %f " % val_accuracy)

                    summary = sess.run(merged,
                                       feed_dict={
                                           alpha_: opt.rec_alpha,
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org,
                                           x_lab_: x_val_dis,
                                           y_: val_y,
                                           dp_ratio_: 1.0
                                       })
                    test_writer.add_summary(summary, uidx)

                    if val_accuracy >= max_val_accuracy:
                        max_val_accuracy = val_accuracy

                        kf_test = get_minibatches_idx(len(test),
                                                      opt.dis_batch_size,
                                                      shuffle=False)
                        prob_test = []
                        for _, test_ind in kf_test:
                            test_sents = [test[t] for t in test_ind]
                            x_test_batch = prepare_data_for_cnn(
                                test_sents, opt)
                            test_prob = sess.run(prob_,
                                                 feed_dict={
                                                     x_lab_: x_test_batch,
                                                     dp_ratio_: 1.0
                                                 })
                            for x in test_prob:
                                prob_test.append(x)

                        probs = []
                        test_truth = []
                        for i in range(len(prob_test)):
                            test_truth.append(test_y[i])
                            probs.append(prob_test[i])

                        count = 0.0
                        for i in range(len(probs)):
                            p = probs[i]
                            if p > 0.5:
                                if test_truth[i] == 1.0:
                                    count += 1.0
                            else:
                                if test_truth[i] == 0.0:
                                    count += 1.0

                        test_accuracy = count * 1.0 / len(probs)

                        print("Test accuracy %f " % test_accuracy)

                        max_test_accuracy = test_accuracy

                def test_input(text):
                    x_input = sent2idx(text, wordtoix, opt)
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_input,
                                       x_org_: x_batch_org
                                   })
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    # res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
                    # print "Original     :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
                    # # print "Permutated   :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                    # if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                    #     print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
                    # print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] +'(' +str(np.round(res['all_p'][i],2))+')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

            print(opt.rec_alpha)
            print("Epoch %d: Max Valid accuracy %f" %
                  (epoch, max_val_accuracy))
            print("Epoch %d: Max Test accuracy %f" %
                  (epoch, max_test_accuracy))

            saver.save(sess, opt.save_path, global_step=epoch)
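
The validation and test blocks above compute accuracy by thresholding the classifier's sigmoid output at 0.5 and comparing against 0/1 labels. A compact vectorized equivalent (a sketch, assuming probs and labels are equal-length 1-D array-likes) is:

import numpy as np

def binary_accuracy(probs, labels, threshold=0.5):
    """Fraction of examples whose thresholded probability matches the label."""
    preds = (np.asarray(probs).ravel() > threshold).astype(int)
    return float(np.mean(preds == np.asarray(labels).ravel()))

# example
print(binary_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 1, 1]))  # 0.75
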
Beispiel #12
0
def run_model(opt, train, val, ixtoword):

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)

    # keep all checkpoints
    saver = tf.train.Saver(max_to_keep=None)

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        ''' validation '''
        valid_index = np.random.choice(len(val), opt.batch_size)
        val_sents = [val[t] for t in valid_index]

        val_sents_permutated = add_noise(val_sents, opt)

        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

        d_loss_val = sess.run(d_loss_,
                              feed_dict={
                                  x_: x_val_batch,
                                  x_org_: x_val_batch_org
                              })
        g_loss_val = sess.run(g_loss_,
                              feed_dict={
                                  x_: x_val_batch,
                                  x_org_: x_val_batch_org
                              })

        res = sess.run(res_,
                       feed_dict={
                           x_: x_val_batch,
                           x_org_: x_val_batch_org
                       })
        try:
            print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + u' '.join([
                ixtoword[x] for x in res['syn_sent'][0] if x != 0
            ]))  #.encode('utf-8', 'ignore').decode("utf8").strip())
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
        except Exception as e:
            print(e)

        # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
        if opt.discrimination:
            print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))

        # enough batches to cover 268590 sentences (assumes opt.batch_size >= 1000)
        for i in range(268590 // 1000 + 1):
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            val_sents_permutated = add_noise(val_sents, opt)
            x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
            x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
            res = sess.run(res_,
                           feed_dict={
                               x_: x_val_batch,
                               x_org_: x_val_batch_org
                           })
            if i == 0:
                valid_text = res['syn_sent']
            else:
                valid_text = np.concatenate((valid_text, res['syn_sent']), 0)

        valid_text = valid_text[:268590]
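        # NOTE: PATH_TO_SAVE is assumed to be defined elsewhere (output file for the generated token ids)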
        np.savetxt(PATH_TO_SAVE, valid_text, fmt='%i', delimiter=' ')
        print('saved!\n\n\n')
        exit()
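        # NOTE: the BLEU evaluation below is unreachable because of the exit() call above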

        val_set = [prepare_for_bleu(s) for s in val]
        bleu_prepared = [prepare_for_bleu(s) for s in res['syn_sent']]
        for i in range(len(val_set) // opt.batch_size):
            batch = val_set[i * opt.batch_size:(i + 1) * opt.batch_size]
            [bleu2s, bleu3s, bleu4s] = cal_BLEU(bleu_prepared, {0: batch})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))
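
The generation loop above collects a fixed number of synthetic sentences by sampling one batch at a time, concatenating the results, and truncating to the target count. A minimal sketch of that pattern, where sample_batch is a hypothetical callable standing in for the sess.run(res_, ...) call:

import numpy as np

def generate_n(sample_batch, n_total, batch_size):
    """Collect at least n_total rows in batch_size chunks, then truncate."""
    n_chunks = int(np.ceil(float(n_total) / batch_size))
    chunks = [sample_batch() for _ in range(n_chunks)]
    return np.concatenate(chunks, 0)[:n_total]

# example with a dummy sampler producing random token ids
out = generate_n(lambda: np.random.randint(0, 100, size=(64, 10)), n_total=200, batch_size=64)
print(out.shape)  # (200, 10)
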
Beispiel #13
0
def run_model(opt, train, val, test, wordtoix, ixtoword):


    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:'+ str(params['Wemb'].shape) + ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())


    uidx = 0
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()



    run_metadata = tf.RunMetadata()


    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)


            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

                #sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt) # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt) # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt) # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO = False) # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print


                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()

                #
                if profile:
                    _, loss = sess.run([train_op, loss_], feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1},options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_], feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})

                #pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)


                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_:is_train })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_:is_train })
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))
                    print("Val Orig :" + " ".join([ixtoword[x] for x in val_sents[0] if x != 0]))
                    #print("Val Perm :" + " ".join([ixtoword[x] for x in val_sents_permutated[0] if x != 0]))
                    print("Val Recon:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0]))

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU([prepare_for_bleu(s) for s in res['rec_sents']], {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join([str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

                    # if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #     print "Org Probs:" + " ".join(
                    #         [ixtoword[x_val_batch_org[0][i]] + '(' + str(np.round(res['all_p'][i], 1)) + ')' for i in
                    #          range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    #     print "Gen Probs:" + " ".join(
                    #         [ixtoword[res['rec_sents'][0][i]] + '(' + str(np.round(res['gen_p'][i], 1)) + ')' for i in
                    #          range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_:is_train })
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                def test_input(text):
                    x_input = sent2idx(text, wordtoix, opt)
                    res = sess.run(res_, feed_dict={x_: x_input, x_org_: x_batch_org})
                    print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                if uidx % opt.print_freq == 0:
                    #pdb.set_trace()
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                    print "Original     :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
                    #print "Permutated   :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
                    print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] +'(' +str(np.round(res['all_p'][i],2))+')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                            tf.get_default_graph(),
                            run_meta=run_metadata,
                            tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path, global_step=epoch)
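
Every training loop in these examples iterates over get_minibatches_idx(len(data), batch_size, shuffle=True) and unpacks (_, index) pairs. A plausible minimal implementation of that helper, inferred from how it is used here rather than taken from the original utils module:

import numpy as np

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Return (minibatch_number, index_array) pairs covering range(n)."""
    idx_list = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = [idx_list[i:i + minibatch_size]
                   for i in range(0, n, minibatch_size)]
    return list(enumerate(minibatches))

# example
for k, index in get_minibatches_idx(10, 4, shuffle=True):
    print(k, index)
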
Beispiel #14
0
def main():
    #global n_words
    # Prepare training and testing data

    opt = COptions(args)
    opt_t = COptions(args)
    # opt_t.n_hid = opt.n_z

    loadpath = (opt.data_dir + "/" + opt.data_name) #if opt.not_philly else '/hdfs/msrlabs/xiag/pt-data/cons/data_cleaned/twitter_small.p'
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    if opt.test:
        test_file = opt.data_dir + opt.test_file 
        test = read_test(test_file, wordtoix)
        # test = [ x for x in test if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # train_filtered = [ x for x in train if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # val_filtered = [ x for x in val if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # print ("Train: %d => %d" % (len(train), len(train_filtered)))
    # print ("Val: %d => %d" % (len(val), len(val_filtered)))
    # train, val = train_filtered, val_filtered
    # del train_filtered, val_filtered

    opt.n_words = len(ixtoword) 
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    # print dict(opt)
    # if opt.model == 'cnn_rnn':
    #     opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    #     opt_t.update_params(args)
        # print dict(opt_t)


    #for d in ['/gpu:0', '/gpu:1', '/gpu:2', '/gpu:3']:
    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            z_ = tf.placeholder(tf.float32, shape=[opt_t.batch_size , opt.n_z * (2 if opt.local_feature else 1)])
            is_train_ = tf.placeholder(tf.bool, name = 'is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
            res_2_ = generate_resp(src_, tgt_, z_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    graph_options=tf.GraphOptions(build_cost_model=1)
    #config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=tf.GraphOptions(build_cost_model=1))
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=graph_options)
    # config.gpu_options.per_process_gpu_memory_fraction = 0.70
    #config = tf.ConfigProto(device_count={'GPU':0})
    #config.gpu_options.allow_growth = True

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()  
                #t_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) #tf.trainable_variables()

                # if opt.load_from_pretrain:
                #     d_vars = [var for var in t_vars if var.name.startswith('d_')]
                #     g_vars = [var for var in t_vars if var.name.startswith('g_')]
                #     l_vars = [var for var in t_vars if var.name.startswith('l_')]
                #     #restore_from_save(g_vars, sess, opt, prefix = 'g_', load_path=opt.restore_dir + "/save/generator2")
                #     restore_from_save(d_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.global_d)
                #     if opt.local_feature:
                #         restore_from_save(l_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.local_d)
                # else:
                loader = restore_from_save(t_vars, sess, opt, load_path = opt.save_path)


            except Exception as e:
                print('Error: ' + str(e))
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            iter_num = int(np.floor(len(test) / opt.batch_size)) + 1
            res_all = []
            val_tgt_all = []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                sents = [test[t % len(test)] for t in test_index]
                for idx in range(opt.n_context, opt.num_turn):
                    # use b (not i) inside the comprehensions so the outer batch index is not shadowed
                    src = [[sents[b][idx - turn] for b in range(opt.batch_size)] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[b][idx] for b in range(opt.batch_size)]
                    val_tgt_all.extend(tgt)
                    if opt.feed_generated and idx != opt.n_context:
                        src[-1] = [[x for x in p if x != 0] for p in res_all[-opt.batch_size:]]

                    x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]  # Batch L
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)

                    feed = merge_two_dicts({ph: d for ph, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0})  # do not use False
                    res_1 = sess.run(res_1_, feed_dict=feed)
                    z_all = np.array(res_1['z'])

                    
                    feed = merge_two_dicts({ph: d for ph, d in zip(src_, x_batch)}, {tgt_: y_batch, z_: z_all, is_train_: 0})  # do not use False
                    res_2 = sess.run(res_2_, feed_dict=feed)
                    res_all.extend(res_2['syn_sent'])

                    # bp()
   
            val_tgt_all = reshaping(val_tgt_all, opt)
            res_all = reshaping(res_all, opt)
            
            save_path = opt.log_path + '.resp.txt'
            if os.path.exists(save_path):
                os.remove(save_path) 
            for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                with open(save_path, "a") as resp_f:
                    resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
            print ("save to:" + save_path)

            if opt.verbose:
                save_path = opt.log_path + '.tgt.txt'
                if os.path.exists(save_path):
                    os.remove(save_path) 
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path, "a") as tgt_f:
                        tgt_f.write(u' '.join([ixtoword[x] for x in val_tgt_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
                print ("save to:" + save_path)

            val_set = [prepare_for_bleu(s) for s in val_tgt_all]
            gen = [prepare_for_bleu(s) for s in res_all]
            [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
            etp_score, dist_score = cal_entropy(gen)

            # print(save_path)
            print('Val BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu1s, bleu2s, bleu3s, bleu4s)]))
            # print('Val Rouge: ' + ' '.join([str(round(it, 3)) for it in (rouge1, rouge2, rouge3, rouge4)]))
            print('Val Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])]))
            print('Val Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])]))
            # print('Val Relevance(G,A,E): ' + ' '.join([str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])]))
            print('Val Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3)))
            if opt.embedding_score:
                with open("../../ssd0/consistent_dialog/data/GoogleNews-vectors-negative300.bin.p", 'rb') as pfile:
                    embedding = cPickle.load(pfile)
                rel_score = cal_relevance(gen, val_set, embedding)
                print('Val Relevance(G,A,E): ' + ' '.join([str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])]))


            if not opt.global_feature or opt.bit is None: exit(0)

        if opt.test:
            iter_num = int(np.floor(len(test) / opt.batch_size)) + 1
            for int_idx in range(opt.int_num):
                res_all = []
                z1, z2, z3 = [], [], []
                val_tgt_all = []
                for i in range(iter_num):
                    test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                    sents = [test[t % len(test)] for t in test_index]
                    for idx in range(opt.n_context, opt.num_turn):
                        src = [[sents[b][idx - turn] for b in range(opt.batch_size)] for turn in range(opt.n_context, 0, -1)]
                        tgt = [sents[b][idx] for b in range(opt.batch_size)]
                        val_tgt_all.extend(tgt)
                        if opt.feed_generated and idx != opt.n_context:
                            src[-1] = [[x for x in p if x != 0] for p in res_all[-opt.batch_size:]]

                        x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]  # Batch L
                        y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                        feed = merge_two_dicts({ph: d for ph, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0})  # do not use False
                        res_1 = sess.run(res_1_, feed_dict=feed)
                        z_all = np.array(res_1['z'])
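                        # interpolation: overwrite the chosen latent dimension with the current grid value in [0, 1]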
                        z_all[:, opt.bit] = np.array([1.0 / float(opt.int_num - 1) * int_idx for _ in range(opt.batch_size)])
                        
                        feed = merge_two_dicts({ph: d for ph, d in zip(src_, x_batch)}, {tgt_: y_batch, z_: z_all, is_train_: 0})  # do not use False
                        res_2 = sess.run(res_2_, feed_dict=feed)
                        res_all.extend(res_2['syn_sent'])
                        z1.extend(res_1['z'])                        
                        z2.extend(z_all)
                        z3.extend(res_2['z_hat'])
                        
                        # bp()

                val_tgt_all = reshaping(val_tgt_all, opt)
                res_all = reshaping(res_all, opt)
                z1 = reshaping(z1, opt)
                z2 = reshaping(z2, opt)
                z3 = reshaping(z3, opt)
                
                save_path = opt.log_path + 'bit' + str(opt.bit) + '.' + str(1.0 / float(opt.int_num - 1) * int_idx) + '.int.txt'
                if os.path.exists(save_path):
                    os.remove(save_path) 
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path, "a") as resp_f:
                        resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
                print ("save to:" + save_path)

                save_path_z = opt.log_path + 'bit' + str(opt.bit) + '.' + str(1.0 / float(opt.int_num - 1) * int_idx) + '.z.txt'
                if os.path.exists(save_path_z):
                    os.remove(save_path_z) 
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path_z, "a") as myfile:
                        #ary = np.array([z1[idx][opt.bit], z2[idx][opt.bit], z3[idx][opt.bit]])
                        #myfile.write(np.array2string(ary, formatter={'float_kind':lambda x: "%.2f" % x}) + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t'))
                        myfile.write(str(z3[idx][opt.bit]) + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t'))

                
                val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                gen = [prepare_for_bleu(s) for s in res_all]
                [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
                etp_score, dist_score = cal_entropy(gen)

                print(save_path)
                print('Val BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu1s, bleu2s, bleu3s, bleu4s)]))
                # print('Val Rouge: ' + ' '.join([str(round(it, 3)) for it in (rouge1, rouge2, rouge3, rouge4)]))
                print('Val Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])]))
                print('Val Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])]))
                # print('Val Relevance(G,A,E): ' + ' '.join([str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])]))
                print('Val Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3)))
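
The interpolation pass above fixes one coordinate of the latent code to each value of a uniform grid over [0, 1] and regenerates responses at every grid point. A standalone sketch of that sweep, where z, bit and int_num mirror z_all, opt.bit and opt.int_num (the names are illustrative):

import numpy as np

def sweep_latent_bit(z, bit, int_num):
    """Yield copies of z with z[:, bit] set to int_num evenly spaced values in [0, 1]."""
    for k in range(int_num):
        z_k = np.array(z, copy=True)
        z_k[:, bit] = k / float(int_num - 1)
        yield z_k

# example: sweep dimension 2 of four random codes over 5 grid points
for z_k in sweep_latent_bit(np.random.randn(4, 8), bit=2, int_num=5):
    print(z_k[:, 2])  # all rows share the same grid value
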
Beispiel #15
0
def run_model(opt, train, val, ixtoword):

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        # is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()

                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        # for i in range(34):
        #     valid_index = np.random.choice(
        #         len(val), opt.batch_size)
        #     val_sents = [val[t] for t in valid_index]
        #     val_sents_permutated = add_noise(val_sents, opt)
        #     x_val_batch = prepare_data_for_cnn(
        #         val_sents_permutated, opt)
        #     x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
        #     res = sess.run(res_, feed_dict={
        #                     x_: x_val_batch, x_org_: x_val_batch_org})
        #     if i == 0:
        #         valid_text = res['syn_sent']
        #     else:
        #         valid_text = np.concatenate(
        #             (valid_text, res['syn_sent']), 0)

        # np.savetxt('./text_news/vae_words.txt', valid_text, fmt='%i', delimiter=' ')
        # pdb.set_trace()

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)

                #sents[0] = np.random.permutation(sents[0])
                x_batch = prepare_data_for_cnn(sents_permutated,
                                               opt)  # Batch L
                x_batch_org = prepare_data_for_rnn(sents, opt)
                d_loss = 0
                g_loss = 0
                if profile:
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run(
                            [dis_op, d_loss_],
                            feed_dict={
                                x_: x_batch,
                                x_org_: x_batch_org
                            },
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run(
                            [gen_op, g_loss_],
                            feed_dict={
                                x_: x_batch,
                                x_org_: x_batch_org
                            },
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                else:
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run([dis_op, d_loss_],
                                             feed_dict={
                                                 x_: x_batch,
                                                 x_org_: x_batch_org
                                             })
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run([gen_op, g_loss_],
                                             feed_dict={
                                                 x_: x_batch,
                                                 x_org_: x_batch_org
                                             })
                ''' validation '''
                if uidx % opt.valid_freq == 0:

                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    x_val_batch = prepare_data_for_cnn(val_sents_permutated,
                                                       opt)
                    x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    d_loss_val = sess.run(d_loss_,
                                          feed_dict={
                                              x_: x_val_batch,
                                              x_org_: x_val_batch_org
                                          })
                    g_loss_val = sess.run(g_loss_,
                                          feed_dict={
                                              x_: x_val_batch,
                                              x_org_: x_val_batch_org
                                          })

                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org
                                   })
                    print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                          (d_loss_val, g_loss_val, res['mean_dist']))
                    print("Sent:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]))  #.encode('utf-8', 'ignore').decode("utf8").strip())
                    print("MMD loss %f, GAN loss %f" %
                          (res['mmd'], res['gan']))
                    # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    for i in range(4):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        val_sents = [val[t] for t in valid_index]
                        val_sents_permutated = add_noise(val_sents, opt)
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                        res = sess.run(res_,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org
                                       })
                        if i == 0:
                            valid_text = res['syn_sent']
                        else:
                            valid_text = np.concatenate(
                                (valid_text, res['syn_sent']), 0)

                    np.savetxt('./text_news/syn_val_words.txt',
                               valid_text,
                               fmt='%i',
                               delimiter=' ')

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['syn_sent']],
                        {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu2s, bleu3s, bleu4s)]))

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org
                                       })
                    test_writer.add_summary(summary, uidx)
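
In the loop above the discriminator and generator are updated on separate schedules: the discriminator step runs whenever uidx % opt.dis_steps == 0 and the generator step whenever uidx % opt.gen_steps == 0. Stripped of the TensorFlow plumbing, the cadence reduces to the sketch below; dis_step and gen_step are hypothetical callables standing in for the two sess.run updates.

def alternating_gan_training(num_steps, dis_steps, gen_steps, dis_step, gen_step):
    """Run discriminator/generator updates on their respective step schedules."""
    d_loss, g_loss = 0.0, 0.0
    for uidx in range(1, num_steps + 1):
        if uidx % dis_steps == 0:
            d_loss = dis_step()
        if uidx % gen_steps == 0:
            g_loss = gen_step()
    return d_loss, g_loss

# example: update the discriminator every step and the generator every 5 steps
alternating_gan_training(100, 1, 5, lambda: 0.7, lambda: 1.3)
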
Beispiel #16
0
def main():
    # Prepare training and testing data

    loadpath = "./data/"

    src_file = loadpath + "Pairs2M.src.num"
    tgt_file = loadpath + "Pairs2M.tgt.num"
    dic_file = loadpath + "Pairs2M.reddit.dic"

    opt = Options()
    opt_t = Options()

    train, val, test, wordtoix, ixtoword = read_pair_data_full(
        src_file,
        tgt_file,
        dic_file,
        max_num=opt.data_size,
        p_f=loadpath + 'demo.p')
    train = [
        x for x in train
        if 2 < len(x[1]) < opt.maxlen - 4 and 2 < len(x[0]) < opt_t.maxlen - 4
    ]
    val = [
        x for x in val
        if 2 < len(x[1]) < opt.maxlen - 4 and 2 < len(x[0]) < opt_t.maxlen - 4
    ]

    if TEST_FLAG:
        test = [test]
        opt.test_freq = 1

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    print(dict(opt))
    if opt.model == 'cnn_rnn':
        opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
        opt_t.update_params()
        print(dict(opt_t))

    print('Total words: %d' % opt.n_words)

    # load w2v
    if os.path.exists(opt.embedding_path_lime):
        with open(opt.embedding_path_lime, 'rb') as pfile:
            embedding = cPickle.load(pfile)
    else:
        w2v = gensim.models.KeyedVectors.load_word2vec_format(
            opt.embedding_path, binary=True)
        embedding = {
            i: copy.deepcopy(w2v[ixtoword[i]])
            for i in range(opt.n_words) if ixtoword[i] in w2v
        }
        with open(opt.embedding_path_lime, 'wb') as pfile:
            cPickle.dump(embedding, pfile, protocol=cPickle.HIGHEST_PROTOCOL)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = tf.placeholder(tf.int32,
                                  shape=[opt.batch_size, opt.sent_len])
            tgt_ = tf.placeholder(tf.int32,
                                  shape=[opt_t.batch_size, opt_t.sent_len])
            res_, gan_cost_d_, train_op_d, gan_cost_g_, train_op_g = dialog_gan(
                src_, tgt_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.95

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:

                t_vars = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES)  #tf.trainable_variables()

                if opt.load_from_ae:
                    save_keys = tensors_key_in_file(
                        opt.load_path)  #t_var g_W:0    key: W
                    ss = [
                        var for var in t_vars
                        if var.name[2:][:-2] in save_keys.keys()
                    ]
                    ss = [
                        var.name[2:] for var in ss
                        if var.get_shape() == save_keys[var.name[2:][:-2]]
                    ]
                    cc = {
                        var.name[2:][:-2]: var
                        for var in t_vars if var.name[2:] in ss
                    }

                    loader = tf.train.Saver(var_list=cc)
                    loader.restore(sess, opt.load_path)

                    print("Loading variables from '%s'." % opt.load_path)
                    print(
                        "Loaded variables:" + " ".join(
                            [var.name
                             for var in t_vars if var.name[2:] in ss]))
                else:
                    save_keys = tensors_key_in_file(opt.load_path)
                    ss = [
                        var for var in t_vars
                        if var.name[:-2] in save_keys.keys()
                    ]
                    ss = [
                        var.name for var in ss
                        if var.get_shape() == save_keys[var.name[:-2]]
                    ]
                    loader = tf.train.Saver(
                        var_list=[var for var in t_vars if var.name in ss])
                    loader.restore(sess, opt.load_path)
                    print("Loading variables from '%s'." % opt.load_path)
                    print("Loaded variables:" + str(ss))
                    # load reverse model
                    try:
                        save_keys = tensors_key_in_file('./save/rev_model')
                        ss = [
                            var for var in t_vars
                            if var.name[:-2] in save_keys.keys()
                            and 'g_rev_' in var.name
                        ]
                        ss = [
                            var.name for var in ss
                            if var.get_shape() == save_keys[var.name[:-2]]
                        ]
                        loader = tf.train.Saver(
                            var_list=[var for var in t_vars if var.name in ss])
                        loader.restore(sess, './save/rev_model')
                        print(
                            "Loading reverse variables from ./save/rev_model")
                        print("Loaded variables:" + str(ss))
                    except Exception as e:
                        print("No reverse model loaded")

            except Exception as e:
                print('Error: ' + str(e))
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0
        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:

                uidx += 1

                if uidx % opt.test_freq == 1:
                    iter_num = len(test) // opt.batch_size + 1
                    res_all, test_tgt_all = [], []

                    for i in range(iter_num):
                        test_index = range(i * opt.batch_size,
                                           (i + 1) * opt.batch_size)
                        test_tgt, test_src = zip(
                            *[test[t % len(test)] for t in test_index])
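                        # t % len(test) wraps indices, so the last batch is
                        # padded by repeating examples from the start.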
                        test_tgt_all.extend(test_tgt)
                        x_batch = prepare_data_for_cnn(test_src, opt)
                        y_batch = prepare_data_for_rnn(
                            test_tgt, opt_t, is_add_GO=False
                        ) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(
                            test_tgt, opt_t)
                        feed = {src_: x_batch, tgt_: y_batch}
                        res = sess.run(res_, feed_dict=feed)
                        res_all.extend(res['syn_sent'])

                    test_set = [prepare_for_bleu(s) for s in test_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: test_set},
                                          is_corpus=opt.is_corpus)
                    [rouge1, rouge2, rouge3, rouge4, rougeL,
                     rouges] = cal_ROUGE(gen, {0: test_set},
                                         is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    bleu_nltk = cal_BLEU_4_nltk(gen,
                                                test_set,
                                                is_corpus=opt.is_corpus)
                    rel_score = cal_relevance(gen, test_set, embedding)

                    print('Test BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)
                    ]))
                    print('Test Rouge: ' + ' '.join([
                        str(round(it, 3))
                        for it in (rouge1, rouge2, rouge3, rouge4)
                    ]))
                    print('Test Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ]))
                    print('Test Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1], dist_score[2],
                                   dist_score[3])
                    ]))
                    print('Test Relevance(G,A,E): ' + ' '.join([
                        str(round(it, 3))
                        for it in (rel_score[0], rel_score[1], rel_score[2])
                    ]))
                    print('Test Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0]) for x in res_all
                            ]), 3)))
                    print('')

                    if TEST_FLAG:
                        exit()

                tgt, src = zip(*[train[t] for t in train_index])
                x_batch = prepare_data_for_cnn(src, opt)  # Batch L

                y_batch = prepare_data_for_rnn(
                    tgt, opt_t, is_add_GO=False
                ) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(
                    tgt, opt_t)

                feed = {src_: x_batch, tgt_: y_batch}
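                # Alternate GAN updates: the discriminator trains every
                # opt.d_freq iterations, the generator every opt.g_freq.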

                if uidx % opt.d_freq == 1:
                    if profile:
                        _, loss_d = sess.run(
                            [train_op_d, gan_cost_d_],
                            feed_dict=feed,
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    else:
                        _, loss_d = sess.run([train_op_d, gan_cost_d_],
                                             feed_dict=feed)

                if uidx % opt.g_freq == 1:
                    if profile:
                        _, loss_g = sess.run(
                            [train_op_g, gan_cost_g_],
                            feed_dict=feed,
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    else:
                        _, loss_g = sess.run([train_op_g, gan_cost_g_],
                                             feed_dict=feed)

                if profile:
                    tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(),
                        run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.
                        PRINT_ALL_TIMING_MEMORY)
                    exit(0)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 1024
                    valid_multiplier = VALID_SIZE // opt.batch_size
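                    # Validate on roughly VALID_SIZE random samples, drawn
                    # in batch_size-sized chunks.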
                    res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        val_tgt, val_src = zip(*[val[t] for t in valid_index])
                        val_tgt_all.extend(val_tgt)
                        x_val_batch = prepare_data_for_cnn(val_src,
                                                           opt)  # Batch L

                        y_val_batch = prepare_data_for_rnn(
                            val_tgt, opt_t, is_add_GO=False
                        ) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(
                            val_tgt, opt_t)

                        feed_val = {src_: x_val_batch, tgt_: y_val_batch}
                        loss_val_d, loss_val_g = sess.run(
                            [gan_cost_d_, gan_cost_g_], feed_dict=feed_val)
                        loss_val_d_all.append(loss_val_d)
                        loss_val_g_all.append(loss_val_g)
                        res = sess.run(res_, feed_dict=feed_val)
                        res_all.extend(res['syn_sent'])

                    print("Validation: loss D %f loss G %f " %
                          (np.mean(loss_val_d_all), np.mean(loss_val_g_all)))
                    #print "Val Perm :" + " ".join([ixtoword[x] for x in val_src_permutated[0] if x != 0])
                    print "Val Source:" + u' '.join([
                        ixtoword[x] for x in val_src[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Target :" + u' '.join([
                        ixtoword[x] for x in val_tgt[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]

                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: val_set},
                                          is_corpus=opt.is_corpus)
                    [rouge1, rouge2, rouge3, rouge4, rougeL,
                     rouges] = cal_ROUGE(gen, {0: val_set},
                                         is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    bleu_nltk = cal_BLEU_4_nltk(gen,
                                                val_set,
                                                is_corpus=opt.is_corpus)
                    rel_score = cal_relevance(gen, val_set, embedding)

                    print('Val BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)
                    ]))
                    print('Val Rouge: ' + ' '.join([
                        str(round(it, 3))
                        for it in (rouge1, rouge2, rouge3, rouge4)
                    ]))
                    print('Val Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ]))
                    print('Val Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1], dist_score[2],
                                   dist_score[3])
                    ]))
                    print('Val Relevance(G,A,E): ' + ' '.join([
                        str(round(it, 3))
                        for it in (rel_score[0], rel_score[1], rel_score[2])
                    ]))
                    print('Val Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0]) for x in res_all
                            ]), 3)))
                    print("")
                    summary = sess.run(merged, feed_dict=feed_val)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="rouge-2", simple_value=rouge2),
                        tf.Summary.Value(tag="etp-4",
                                         simple_value=etp_score[3])
                    ])
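                    # Hand-built scalar summaries expose BLEU-2, ROUGE-2 and
                    # 4-gram entropy to TensorBoard alongside the merged ones.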

                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.print_freq == 1:
                    print("Iteration %d: loss D %f loss G %f" %
                          (uidx, loss_d, loss_g))

                    res = sess.run(res_, feed_dict=feed)

                    if opt.grad_penalty:
                        print "grad_penalty: " + str(res['gp'])
                    print "Source:" + u' '.join([
                        ixtoword[x] for x in x_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""

                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.save_freq == 1:
                    saver.save(sess, opt.save_path)
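
The restore logic above relies on a helper, tensors_key_in_file, that maps each tensor name stored in a checkpoint to its shape. Its implementation is not shown in this listing; a minimal sketch using the standard TF1 checkpoint reader (an assumed implementation, not the repository's own) could look like this:

import tensorflow as tf

def tensors_key_in_file(file_name):
    # Sketch: return {tensor_name: shape} for every tensor in a TF1
    # checkpoint, or None if the file cannot be read.
    try:
        reader = tf.train.NewCheckpointReader(file_name)
        return reader.get_variable_to_shape_map()
    except Exception:
        return None

With such a map, the filtering above restores exactly those variables whose checkpoint key (the variable name minus the ':0' suffix, and minus the scope prefix in the load_from_ae case) is present with a matching shape.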
Beispiel #17
0
def run_model(opt, train, val, ixtoword):

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' + str(
                      (opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    # config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('Load pretrain successfully')

            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]

                sents_permutated = add_noise(sents, opt)
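                # add_noise() corrupts the inputs while the clean sentences
                # remain the reconstruction target (denoising autoencoder).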

                #sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L

                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated,
                                                   opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated,
                                                   opt,
                                                   is_add_GO=False)  # Batch L

                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={
                            x_: x_batch,
                            x_org_: x_batch_org,
                            is_train_: 1
                        },
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org,
                                           is_train_: 1
                                       })

                if uidx % opt.valid_freq == 0:
                    is_train = None
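                    # The validation feeds below pass is_train_ = None rather
                    # than 1, presumably to disable training-mode behaviour.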
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]

                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={
                                            x_: x_val_batch,
                                            x_org_: x_val_batch_org,
                                            is_train_: is_train
                                        })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_val_batch,
                                       x_org_: x_val_batch_org,
                                       is_train_: is_train
                                   })
                    np.savetxt(opt.save_txt + '/rec_val_words.txt',
                               res['rec_sents'],
                               fmt='%i',
                               delimiter=' ')
                    try:
                        print("Orig:" + u' '.join([
                            ixtoword[x]
                            for x in x_val_batch_org[0] if x != 0 and x != 1
                        ]))  #.encode('utf-8', 'ignore').strip()
                        print("Sent:" + u' '.join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ]))  #.encode('utf-8', 'ignore').strip()
                    except Exception:
                        # skip printing if an index cannot be mapped to a word
                        pass
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org,
                                           is_train_: is_train
                                       })
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                if uidx % opt.print_freq == 1:
                    #pdb.set_trace()
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={
                                       x_: x_batch,
                                       x_org_: x_batch_org,
                                       is_train_: 1
                                   })
                    np.savetxt(opt.save_txt + '/rec_train_words.txt',
                               res['rec_sents'],
                               fmt='%i',
                               delimiter=' ')
                    try:
                        print("Orig:" + u' '.join([
                            ixtoword[x]
                            for x in x_batch_org[0] if x != 0 and x != 1
                        ]))  #.encode('utf-8').strip()
                        print("Sent:" + u' '.join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ]))  #.encode('utf-8').strip()
                    except Exception:
                        # skip printing if an index cannot be mapped to a word
                        pass
                    summary = sess.run(merged,
                                       feed_dict={
                                           x_: x_batch,
                                           x_org_: x_batch_org,
                                           is_train_: 1
                                       })
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                            tf.get_default_graph(),
                            run_meta=run_metadata,
                            tfprof_options=tf.contrib.tfprof.model_analyzer.
                            PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path, global_step=epoch)
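
run_model restores its pretrained weights through restore_from_save, which the next example uses as well. Assuming it applies the same name-and-shape matching as the manual loader shown earlier (and that opt.load_path is the default checkpoint, which is an assumption), a sketch could be:

def restore_from_save(t_vars, sess, opt, load_path=None):
    # Sketch: restore every variable whose checkpoint key (its name
    # minus the ':0' suffix) exists in the file with a matching shape.
    load_path = load_path or opt.load_path
    save_keys = tensors_key_in_file(load_path)
    matched = [
        var for var in t_vars
        if var.name[:-2] in save_keys
        and var.get_shape() == save_keys[var.name[:-2]]
    ]
    loader = tf.train.Saver(var_list=matched)
    loader.restore(sess, load_path)
    return loader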
Beispiel #18
0
def main():

    opt = COptions(args)
    opt_t = COptions(args)

    loadpath = opt.data_dir + "/" + opt.data_name
    print("loadpath: " + loadpath)
    with open(loadpath, "rb") as f:
        x = cPickle.load(f)
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    if opt.test:
        test_file = opt.data_dir + opt.test_file
        test = read_test(test_file, wordtoix)

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
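            # get_features exposes a global code res_1['z'] and a local code
            # res_1['z_l']; both are collected below when opt.test is set.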
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                if opt.load_from_pretrain:
                    d_vars = [var for var in t_vars if var.name.startswith('d_')]
                    l_vars = [var for var in t_vars if var.name.startswith('l_')]
                    restore_from_save(d_vars, sess, opt,
                                      load_path=opt.restore_dir + "/save/" + opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars, sess, opt,
                                          load_path=opt.restore_dir + "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars, sess, opt, load_path=opt.save_path)

            except Exception as e:
                print('Error: ' + str(e))
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            iter_num = len(test) // opt.batch_size + 1
            z_all, z_all_l = [], []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                sents = [test[t % len(test)] for t in test_index]
                src = [[sents[j][0] for j in range(opt.batch_size)]]
                tgt = [sents[j][0] for j in range(opt.batch_size)]
                x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]
                print("Source: " + u' '.join([
                    ixtoword[x] for s in x_batch for x in s[0] if x != 0
                ]).encode('utf-8').strip())
                y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                # src_ is a list of placeholders (one per context sentence),
                # so each one is mapped to its batch before adding tgt_
                # and is_train_.
                feed = merge_two_dicts({p: d for p, d in zip(src_, x_batch)},
                                       {tgt_: y_batch, is_train_: 0})
                res_1 = sess.run(res_1_, feed_dict=feed)
                z_all.extend(res_1['z'])
                z_all_l.extend(res_1['z_l'])

            save_path_z = opt.log_path + '.global.z.txt'
            print(save_path_z)
            if os.path.exists(save_path_z):
                os.remove(save_path_z)
            with open(save_path_z, "a") as myfile:
                for line in z_all[:len(test)]:
                    for z_it in line:
                        myfile.write(str(z_it) + '\t')
                    myfile.write('\n')

            save_path_z = opt.log_path + '.local.z.txt'
            print(save_path_z)
            if os.path.exists(save_path_z):
                os.remove(save_path_z)
            with open(save_path_z, "a") as myfile:
                for line in z_all_l[:len(test)]:
                    for z_it in line:
                        myfile.write(str(z_it) + '\t')
                    myfile.write('\n')
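            # Equivalent alternative (a sketch, not the original code): each
            # feature matrix could be written in one call, modulo the
            # trailing tab the loops above emit on every row:
            #   np.savetxt(save_path_z, np.asarray(z_all_l[:len(test)]),
            #              fmt='%s', delimiter='\t')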