# Example #1 (rating: 0)
def run_model(opt, train, val, ixtoword):
    """Build the textGAN graph and train it, with periodic validation.

    Embedding:
        ``Wemb`` is read from ``./param_g.npz``. If its shape is
        ``(opt.n_words, opt.embed_size)`` it is installed as ``opt.W_emb``.
        Otherwise, or when the file is missing, ``opt.fix_emb`` is set to
        False.

    Training:
        The discriminator op runs every ``opt.dis_steps`` iterations and the
        generator op every ``opt.gen_steps`` iterations.

    Validation, every ``opt.valid_freq`` iterations:
        - losses, one sample sentence and BLEU scores are printed;
        - synthetic sentences from four fresh validation batches are written
          to ``./text_news/syn_val_words.txt``;
        - a TensorBoard summary is logged under ``opt.log_path + '/test'``.

    Relies on a module-level ``profile`` flag. When it is truthy, every
    training ``sess.run`` requests a full trace into a shared
    ``tf.RunMetadata``.

    Args:
        opt: options object; mutated (``W_emb`` / ``fix_emb`` may be set).
        train: indexable collection of training sentences (presumably lists
            of token ids -- TODO confirm).
        val: indexable collection of validation sentences, same format.
        ixtoword: token id -> word mapping used to print samples.
    """
    try:
        # Read the array once and close the .npz archive immediately.
        with np.load('./param_g.npz') as params:
            saved_emb = params['Wemb']
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False
    else:
        expected_shape = (opt.n_words, opt.embed_size)
        if saved_emb.shape == expected_shape:
            print('Use saved embedding.')
            opt.W_emb = saved_emb
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(saved_emb.shape) + ' opt: ' + str(expected_shape))
            opt.fix_emb = False

    with tf.device('/gpu:1'):
        # The batch dimension is fixed: every fed batch must have exactly
        # opt.batch_size rows.
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    # NOTE(review): this saver is never used below, so this function writes
    # no checkpoints.
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()
    # Extra sess.run arguments for training steps; full tracing only when
    # profiling.
    if profile:
        run_kwargs = {
            'options': tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            'run_metadata': run_metadata,
        }
    else:
        run_kwargs = {}

    def feed(x_batch, x_batch_org):
        """Map a (noised CNN-format, clean RNN-format) batch pair onto the placeholders."""
        return {x_: x_batch, x_org_: x_batch_org}

    def sample_val_batch():
        """Draw opt.batch_size validation sentences (with replacement) and prepare them."""
        valid_index = np.random.choice(len(val), opt.batch_size)
        val_sents = [val[t] for t in valid_index]
        x_val_batch = prepare_data_for_cnn(add_noise(val_sents, opt), opt)
        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
        return val_sents, x_val_batch, x_val_batch_org

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)

        def validate(step):
            """Print validation metrics and samples, dump sentences, log a test summary."""
            val_sents, x_val_batch, x_val_batch_org = sample_val_batch()
            # A single forward pass yields both losses and the sample dict,
            # so the printed numbers all describe the same run.
            d_loss_val, g_loss_val, res = sess.run(
                [d_loss_, g_loss_, res_],
                feed_dict=feed(x_val_batch, x_val_batch_org))
            print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + u' '.join(
                [ixtoword[x] for x in res['syn_sent'][0] if x != 0]))
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
            if opt.discrimination:
                print("Real Prob %f Fake Prob %f" %
                      (res['prob_r'], res['prob_f']))

            # Dump synthetic sentences generated from four fresh batches.
            syn_batches = []
            for _ in range(4):
                val_sents, x_val_batch, x_val_batch_org = sample_val_batch()
                res = sess.run(res_,
                               feed_dict=feed(x_val_batch, x_val_batch_org))
                syn_batches.append(res['syn_sent'])
            np.savetxt('./text_news/syn_val_words.txt',
                       np.concatenate(syn_batches, 0),
                       fmt='%i',
                       delimiter=' ')

            # BLEU and the summary below use the last of those four batches.
            val_set = [prepare_for_bleu(s) for s in val_sents]
            bleu2s, bleu3s, bleu4s = cal_BLEU(
                [prepare_for_bleu(s) for s in res['syn_sent']], {0: val_set})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

            summary = sess.run(merged,
                               feed_dict=feed(x_val_batch, x_val_batch_org))
            test_writer.add_summary(summary, step)

        try:
            sess.run(tf.global_variables_initializer())
            if opt.restore:
                try:
                    restore_from_save(tf.trainable_variables(), sess, opt)
                    print('\nload successfully\n')
                except Exception as e:
                    # Best effort: fall back to a fresh initialization.
                    print(e)
                    print("No saving session, using random initialization")
                    sess.run(tf.global_variables_initializer())

            uidx = 0
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train), opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    sents_permutated = add_noise(sents, opt)
                    train_feed = feed(
                        prepare_data_for_cnn(sents_permutated, opt),
                        prepare_data_for_rnn(sents, opt))
                    if uidx % opt.dis_steps == 0:
                        sess.run(dis_op, feed_dict=train_feed, **run_kwargs)
                    if uidx % opt.gen_steps == 0:
                        sess.run(gen_op, feed_dict=train_feed, **run_kwargs)
                    if uidx % opt.valid_freq == 0:
                        validate(uidx)
        finally:
            # FileWriter buffers events; closing flushes them to disk.
            train_writer.close()
            test_writer.close()
# Example #2 (rating: 0)
def run_model(opt, train, val, ixtoword):
    """Build the single-input textGAN graph and train it.

    Embedding:
        ``Wemb`` is read from ``./param_g.npz``. If its shape is
        ``(opt.n_words, opt.embed_size)`` it is installed as ``opt.W_emb``.
        Otherwise, or when the file is missing, ``opt.fix_emb`` is set to
        False.

    Training:
        Only full batches (exactly ``opt.batch_size`` rows after
        ``prepare_data_for_cnn``) are trained on. The discriminator op runs
        every ``opt.dis_steps`` iterations and the generator op every
        ``opt.gen_steps`` iterations.

    Periodic reporting:
        - every ``opt.valid_freq`` iterations: validation losses, one sample
          sentence and BLEU scores are printed, and a test summary is logged;
        - every ``opt.print_freq`` iterations on a full batch: training
          metrics are printed and a train summary is logged;
        - a checkpoint is saved to ``opt.save_path`` after every epoch.

    Relies on a module-level ``profile`` flag. When it is truthy, training
    steps are fully traced and a tfprof timing/memory report is printed.

    Args:
        opt: options object; mutated (``W_emb`` / ``fix_emb`` may be set).
        train: indexable collection of training sentences (presumably lists
            of token ids -- TODO confirm).
        val: indexable collection of validation sentences, same format.
        ixtoword: token id -> word mapping used to print samples.
    """
    import sys  # for the flushed progress indicator

    try:
        # Read the array once and close the .npz archive immediately.
        with np.load('./param_g.npz') as params:
            saved_emb = params['Wemb']
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False
    else:
        expected_shape = (opt.n_words, opt.embed_size)
        if saved_emb.shape == expected_shape:
            print('Use saved embedding.')
            opt.W_emb = saved_emb
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(saved_emb.shape) + ' opt: ' + str(expected_shape))
            opt.fix_emb = False

    with tf.device('/gpu:0'):
        # The batch dimension is fixed: every fed batch must have exactly
        # opt.batch_size rows.
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        # NOTE(review): x_org_ and is_train_ are created but never passed to
        # textGAN nor fed in this function.
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, opt)
        merged = tf.summary.merge_all()

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()
    # Extra sess.run arguments for training steps; full tracing only when
    # profiling.
    if profile:
        run_kwargs = {
            'options': tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            'run_metadata': run_metadata,
        }
    else:
        run_kwargs = {}

    def sentence_text(ids):
        """Join the words for ``ids`` (skipping id 0) and return UTF-8 bytes."""
        return u' '.join([ixtoword[x] for x in ids if x != 0]).encode('utf-8').strip()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)

        def validate(step):
            """Print validation metrics on a random noised batch and log a test summary."""
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            x_val_batch = prepare_data_for_cnn(add_noise(val_sents, opt), opt)
            val_feed = {x_: x_val_batch}

            # A single forward pass yields both losses and the sample dict,
            # so the printed numbers all describe the same run.
            d_loss_val, g_loss_val, res = sess.run(
                [d_loss_, g_loss_, res_], feed_dict=val_feed)
            print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + sentence_text(res['syn_sent'][0]))
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
            np.savetxt('./text/rec_val_words.txt', res['syn_sent'],
                       fmt='%i', delimiter=' ')
            if opt.discrimination:
                print("Real Prob %f Fake Prob %f" %
                      (res['prob_r'], res['prob_f']))

            val_set = [prepare_for_bleu(s) for s in val_sents]
            bleu2s, bleu3s, bleu4s = cal_BLEU(
                [prepare_for_bleu(s) for s in res['syn_sent']], {0: val_set})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

            summary = sess.run(merged, feed_dict=val_feed)
            test_writer.add_summary(summary, step)

        def report_train(step, x_batch, d_loss, g_loss):
            """Print training metrics for ``x_batch`` and log a train summary."""
            train_feed = {x_: x_batch}
            res = sess.run(res_, feed_dict=train_feed)
            real_f = res['real_f']
            # sqrt of the median squared Euclidean distance over all ordered
            # pairs of rows of real_f (self-pairs included).
            median_dis = np.sqrt(np.median(
                [((a - b) ** 2).sum() for a in real_f for b in real_f]))
            print("Iteration %d: d_loss %f, g_loss %f, mean_dist %f, realdist median %f" %
                  (step, d_loss, g_loss, res['mean_dist'], median_dis))
            np.savetxt('./text/rec_train_words.txt', res['syn_sent'],
                       fmt='%i', delimiter=' ')
            print("Sent:" + sentence_text(res['syn_sent'][0]))

            summary = sess.run(merged, feed_dict=train_feed)
            train_writer.add_summary(summary, step)
            if profile:
                tf.contrib.tfprof.model_analyzer.print_model_analysis(
                    tf.get_default_graph(),
                    run_meta=run_metadata,
                    tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)

        try:
            sess.run(tf.global_variables_initializer())
            if opt.restore:
                try:
                    restore_from_save(tf.trainable_variables(), sess, opt)
                except Exception as e:
                    # Best effort: fall back to a fresh initialization.
                    print(e)
                    print("No saving session, using random initialization")
                    sess.run(tf.global_variables_initializer())

            uidx = 0
            for epoch in range(opt.max_epochs):
                print("\nStarting epoch %d\n" % epoch)
                kf = get_minibatches_idx(len(train), opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    # In-place progress counter. The trailing space separates
                    # it from any line printed next, and the flush makes it
                    # visible despite the missing newline.
                    sys.stdout.write("\rIter: %d " % uidx)
                    sys.stdout.flush()
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    sents_permutated = add_noise(sents, opt)
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)

                    # A batch whose row count differs from the fixed
                    # placeholder dimension cannot be fed. Skip it for both
                    # training and the training report.
                    full_batch = x_batch.shape[0] == opt.batch_size
                    if full_batch:
                        d_loss = 0
                        g_loss = 0
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run([dis_op, d_loss_],
                                                 feed_dict={x_: x_batch},
                                                 **run_kwargs)
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run([gen_op, g_loss_],
                                                 feed_dict={x_: x_batch},
                                                 **run_kwargs)

                    if uidx % opt.valid_freq == 0:
                        validate(uidx)

                    if full_batch and uidx % opt.print_freq == 0:
                        report_train(uidx, x_batch, d_loss, g_loss)

                saver.save(sess, opt.save_path, global_step=epoch)
        finally:
            # FileWriter buffers events; closing flushes them to disk.
            train_writer.close()
            test_writer.close()
# Example #3 (rating: 0)
                save_path = saver.save(sess, logPath + timestr + ".ckpt")
                logger.info('Model saved in file: %s' % save_path)

                #data.save('data-' +         timestr)
                #dis.save('discriminator-' + timestr)
                #gen.save('generator-' +     timestr)

                logger.info('Done ...')

            if np.mod(epoch, validFreq) == 0:
                val_z = np.random.uniform(
                    -1, 1, [batch_size, input_dim * len(window)])
                #val_z = tf.random_uniform([batch_size, input_dim * len(window)], minval=-1, maxval=1)
                predset = sess.run([result3], feed_dict={x: train_x, z: val_z})
                predset = predset[0]
                [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                    predset, test, data, ngram, debug
                )  # Check def of this func why need to pass <data> object

                logger.info('Valid BLEU2 = {}, BLEU3 = {}, BLEU4 = {}'.format(
                    bleu2s, bleu3s, bleu4s))
                print('Valid BLEU (2, 3, 4): ' + ' '.join(
                    [str(round(x, 3)) for x in (bleu2s, bleu3s, bleu4s)]))

                #print ('Valid KDE_INPUT = {} and KDE = {}'.format(kde_input, kde))
        print('epoch {} finished, total time left {}, this epoch {}'.format(
            epoch,
            time.time() - total_start_time,
            time.time() - epoch_start_time))
# Example #4 (rating: 0)
def main():
    """Train the ``auto_encoder`` model on ``./data/hotel_reviews.p``.

    Data:
        The pickle's first four entries are used as ``train``, ``val``,
        ``wordtoix`` and ``ixtoword``. One extra token id (the last one) is
        mapped to ``'GO_'``.

    Input corruption:
        Each batch is corrupted according to ``opt.substitution``, which
        selects ``substitute_sent`` ('s'), ``permutate_sent`` ('p'),
        ``add_sent`` ('a') or ``delete_sent`` ('d'). Any other value leaves
        the batch unchanged.
        The corrupted batch is fed to ``x_`` and the clean batch to
        ``x_org_``, each in CNN or RNN format depending on ``opt.model``.

    Periodic reporting:
        - every ``opt.valid_freq`` iterations: validation loss, example
          sentences and BLEU scores are printed, and a test summary is logged;
        - every ``opt.print_freq`` iterations: training samples are printed
          and a train summary is logged;
        - a checkpoint is saved to ``opt.save_path`` after every epoch.
    """
    loadpath = "./data/hotel_reviews.p"
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(loadpath, "rb") as f:
        x = cPickle.load(f)
    train, val = x[0], x[1]
    wordtoix, ixtoword = x[2], x[3]

    opt = Options()
    opt.n_words = len(ixtoword) + 1
    ixtoword[opt.n_words - 1] = 'GO_'
    print(dict(opt))
    print('Total words: %d' % opt.n_words)

    try:
        # Read the array once and close the .npz archive immediately.
        with np.load('./param_g.npz') as params:
            saved_emb = params['Wemb']
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False
    else:
        expected_shape = (opt.n_words, opt.embed_size)
        if saved_emb.shape == expected_shape:
            print('Use saved embedding.')
            opt.W_emb = saved_emb
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(saved_emb.shape) + ' opt: ' + str(expected_shape))
            opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, loss_, train_op = auto_encoder(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    def add_noise_by_mode(sents):
        """Corrupt ``sents`` as selected by ``opt.substitution``; unknown modes return them as-is."""
        if opt.substitution == 's':
            return substitute_sent(sents, opt)
        if opt.substitution == 'p':
            return permutate_sent(sents, opt)
        if opt.substitution == 'a':
            return add_sent(sents, opt)
        if opt.substitution == 'd':
            return delete_sent(sents, opt)
        return sents

    def make_batches(sents, sents_permutated):
        """Return ``(x_batch, x_batch_org)`` formatted for the model's encoder/decoder types."""
        if opt.model not in ('rnn_rnn', 'cnn_rnn'):
            x_batch_org = prepare_data_for_cnn(sents, opt)
        else:
            x_batch_org = prepare_data_for_rnn(sents, opt)
        if opt.model != 'rnn_rnn':
            x_batch = prepare_data_for_cnn(sents_permutated, opt)
        else:
            x_batch = prepare_data_for_rnn(sents_permutated, opt,
                                           is_add_GO=False)
        return x_batch, x_batch_org

    def to_text(ids):
        """Join the words for ``ids``, skipping id 0."""
        return " ".join([ixtoword[t] for t in ids if t != 0])

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)

        def validate(step):
            """Evaluate one random validation batch, print samples/BLEU, log a test summary."""
            # NOTE(review): the graph is already built here; whether toggling
            # opt.is_train has any effect depends on auto_encoder -- confirm.
            opt.is_train = False
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            # Corrupt the validation sentences themselves (never the current
            # training batch), so input and target come from the same sentences.
            val_sents_permutated = add_noise_by_mode(val_sents)
            x_val_batch, x_val_batch_org = make_batches(val_sents,
                                                        val_sents_permutated)
            val_feed = {x_: x_val_batch, x_org_: x_val_batch_org}

            # One forward pass yields both the loss and the sample dict.
            loss_val, res = sess.run([loss_, res_], feed_dict=val_feed)
            print("Validation loss %f " % (loss_val))
            if opt.discrimination:
                print("Real Prob %f Fake Prob %f" %
                      (res['prob_r'], res['prob_f']))
            print("Val Orig :" + to_text(val_sents[0]))
            print("Val Perm :" + to_text(val_sents_permutated[0]))
            print("Val Recon:" + to_text(res['rec_sents'][0]))

            val_set = [prepare_for_bleu(s) for s in val_sents]
            bleu2s, bleu3s, bleu4s = cal_BLEU(
                [prepare_for_bleu(s) for s in res['rec_sents']],
                {0: val_set})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

            summary = sess.run(merged, feed_dict=val_feed)
            test_writer.add_summary(summary, step)
            opt.is_train = True

        def report_train(step, loss, sents, sents_permutated, train_feed):
            """Print training samples for the current batch and log a train summary."""
            print("Iteration %d: loss %f " % (step, loss))
            res = sess.run(res_, feed_dict=train_feed)
            print("Original     :" + to_text(sents[0]))
            print("Permutated   :" + to_text(sents_permutated[0]))
            if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                print("Reconstructed:" + to_text(res['rec_sents_feed_y'][0]))
            print("Reconstructed:" + to_text(res['rec_sents'][0]))

            summary = sess.run(merged, feed_dict=train_feed)
            train_writer.add_summary(summary, step)

        try:
            sess.run(tf.global_variables_initializer())
            if opt.restore:
                try:
                    restore_from_save(tf.trainable_variables(), sess, opt)
                except Exception as e:
                    # Best effort: fall back to a fresh initialization.
                    print(e)
                    print("No saving session, using random initialization")
                    sess.run(tf.global_variables_initializer())

            uidx = 0
            for epoch in range(opt.max_epochs):
                print("Starting epoch %d" % epoch)
                kf = get_minibatches_idx(len(train), opt.batch_size,
                                         shuffle=True)
                for _, train_index in kf:
                    uidx += 1
                    sents = [train[t] for t in train_index]
                    sents_permutated = add_noise_by_mode(sents)
                    x_batch, x_batch_org = make_batches(sents,
                                                        sents_permutated)
                    train_feed = {x_: x_batch, x_org_: x_batch_org}

                    _, loss = sess.run([train_op, loss_], feed_dict=train_feed)

                    if uidx % opt.valid_freq == 0:
                        validate(uidx)

                    if uidx % opt.print_freq == 0:
                        report_train(uidx, loss, sents, sents_permutated,
                                     train_feed)

                saver.save(sess, opt.save_path, global_step=epoch)
        finally:
            # FileWriter buffers events; closing flushes them to disk.
            train_writer.close()
            test_writer.close()
# Example #5 (rating: 0)
def run_model(opt, train, val, test, wordtoix, ixtoword):
    """Train the denoising auto-encoder, validating and logging periodically.

    Builds the ``auto_encoder`` graph, optionally restores the trainable
    variables with ``restore_from_save``, then for ``opt.max_epochs`` epochs
    trains on shuffled minibatches whose inputs are ``add_noise``-corrupted
    sentences and whose targets are the original sentences.  Every
    ``opt.valid_freq`` iterations one random validation batch is evaluated
    (loss, a sample reconstruction, BLEU-2/3/4 and CER); every
    ``opt.print_freq`` iterations a training sample is printed.  Summaries
    are written under ``opt.log_path + '/train'`` and
    ``opt.log_path + '/test'``, and a checkpoint is saved to
    ``opt.save_path`` (without a ``global_step`` suffix) after each epoch.

    Args:
        opt: options object.  Fields read here include ``batch_size``,
            ``sent_len``, ``restore``, ``max_epochs``, ``model``,
            ``valid_freq``, ``print_freq``, ``discrimination``, ``char``,
            ``log_path`` and ``save_path``.
        train: indexable collection of training sentences (token indices).
        val: indexable collection of validation sentences (token indices).
        test: not used by this function.
        wordtoix: not used by this function.
        ixtoword: maps a token index to its string; used for printing.

    Note:
        Reads the module-level flag ``profile``; when true, training steps
        are run with full tracing and a tfprof timing/memory report is
        printed at every ``opt.print_freq`` iteration.
    """
    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()

    def _prepare_batch(sents, sents_noised):
        """Return ``(model_input, target)`` arrays for one batch."""
        # 'rnn_rnn' and 'cnn_rnn' models take RNN-prepared targets.
        if opt.model in ('rnn_rnn', 'cnn_rnn'):
            target = prepare_data_for_rnn(sents, opt)
        else:
            target = prepare_data_for_cnn(sents, opt)
        # Only 'rnn_rnn' takes RNN-prepared input, and without a GO token.
        if opt.model == 'rnn_rnn':
            model_input = prepare_data_for_rnn(sents_noised, opt,
                                               is_add_GO=False)
        else:
            model_input = prepare_data_for_cnn(sents_noised, opt)
        return model_input, target

    def _decode(indices):
        """Render token indices as text, skipping index 0."""
        # Character-level vocabularies are joined without separators.
        sep = '' if opt.char else ' '
        return sep.join([ixtoword[x] for x in indices if x != 0])

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                restore_from_save(tf.trainable_variables(), sess, opt)
            except Exception as e:
                # Best effort: fall back to a fresh random initialization.
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                x_batch, x_batch_org = _prepare_batch(sents, sents_permutated)
                train_feed = {x_: x_batch, x_org_: x_batch_org,
                              is_train_: True}

                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict=train_feed,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_], feed_dict=train_feed)

                if uidx % opt.valid_freq == 0:
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    x_val_batch, x_val_batch_org = _prepare_batch(
                        val_sents, val_sents_permutated)
                    # Evaluate with the is_train_ placeholder explicitly False.
                    val_feed = {x_: x_val_batch, x_org_: x_val_batch_org,
                                is_train_: False}

                    loss_val = sess.run(loss_, feed_dict=val_feed)
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_, feed_dict=val_feed)
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    print("Val Orig :" + _decode(val_sents[0]))
                    print("Val Perm :" + _decode(val_sents_permutated[0]))
                    print("Val Recon:" + _decode(res['rec_sents'][0]))

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    bleu2s, bleu3s, bleu4s = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

                    val_set_char = [
                        prepare_for_cer(s, ixtoword) for s in val_sents
                    ]
                    cer = cal_cer([
                        prepare_for_cer(s, ixtoword) for s in res['rec_sents']
                    ], val_set_char)
                    print('Val CER: ' + str(round(cer, 3)))
                    # CER is computed in Python, so it is logged through a
                    # hand-built Summary proto rather than a graph summary.
                    cer_summary = tf.Summary(
                        value=[tf.Summary.Value(tag='CER', simple_value=cer)])

                    summary = sess.run(merged, feed_dict=val_feed)
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(cer_summary, uidx)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict=train_feed)

                    print("Original     :" + _decode(sents[0]))
                    print("Permutated   :" + _decode(sents_permutated[0]))
                    if opt.model in ('rnn_rnn', 'cnn_rnn'):
                        print("Reconstructed:" +
                              _decode(res['rec_sents_feed_y'][0]))
                    print("Reconstructed:" + _decode(res['rec_sents'][0]))

                    summary = sess.run(merged, feed_dict=train_feed)
                    train_writer.add_summary(summary, uidx)
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                            tf.get_default_graph(),
                            run_meta=run_metadata,
                            tfprof_options=tf.contrib.tfprof.model_analyzer.
                            PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path)
示例#6
0
def run_model(opt, train, val, ixtoword):
    """Restore a textGAN model, report on validation data, dump samples, exit.

    Optionally seeds ``opt.W_emb`` from ``./param_g.npz`` (key ``'Wemb'``).
    If the file is missing, or its shape differs from
    ``(opt.n_words, opt.embed_size)``, ``opt.fix_emb`` is set to False
    instead.

    After building the ``textGAN`` graph and, when ``opt.restore`` is set,
    restoring the trainable variables, the function evaluates one random
    validation batch and prints the losses and a synthetic sentence.  It
    then runs enough random validation batches to collect 268590 synthetic
    sentences (``res['syn_sent']``).  These are written as integer token
    indices to the module-level path ``PATH_TO_SAVE``, and the process is
    terminated with ``sys.exit()``.

    Args:
        opt: options object (reads e.g. ``n_words``, ``embed_size``,
            ``batch_size``, ``sent_len``, ``restore``, ``discrimination``,
            ``log_path``).
        train: not used by this function.
        val: indexable collection of validation sentences (token indices).
        ixtoword: maps a token index to its string; used for printing.
    """
    import sys

    # Number of synthetic sentences written to PATH_TO_SAVE.
    num_samples = 268590

    try:
        # Close the .npz archive as soon as the matrix has been read.
        with np.load('./param_g.npz') as params:
            w_emb = params['Wemb']
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False
    else:
        if w_emb.shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = w_emb
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(w_emb.shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)

    # Keep all checkpoints.
    saver = tf.train.Saver(max_to_keep=None)

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                restore_from_save(tf.trainable_variables(), sess, opt)
                print('\nload successfully\n')
            except Exception as e:
                # Best effort: fall back to a fresh random initialization.
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        def _random_val_feed():
            """Sample a random (noised) validation batch; return a feed dict."""
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            val_sents_permutated = add_noise(val_sents, opt)
            x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
            x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
            return {x_: x_val_batch, x_org_: x_val_batch_org}

        # Validation report on a single random batch.
        val_feed = _random_val_feed()
        d_loss_val = sess.run(d_loss_, feed_dict=val_feed)
        g_loss_val = sess.run(g_loss_, feed_dict=val_feed)
        res = sess.run(res_, feed_dict=val_feed)
        try:
            print("Validation d_loss %f, g_loss %f  mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + u' '.join(
                [ixtoword[x] for x in res['syn_sent'][0] if x != 0]))
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
        except Exception as e:
            # Best-effort reporting: show the problem but keep going.
            print(e)

        if opt.discrimination:
            print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))

        # Ceil-divide so at least num_samples sentences are produced for any
        # opt.batch_size; the surplus of the last batch is trimmed below.
        num_batches = (num_samples + opt.batch_size - 1) // opt.batch_size
        syn_batches = []
        for _ in range(num_batches):
            res = sess.run(res_, feed_dict=_random_val_feed())
            syn_batches.append(res['syn_sent'])
        # Concatenate once instead of re-copying the growing array each step.
        valid_text = np.concatenate(syn_batches, 0)[:num_samples]

        np.savetxt(PATH_TO_SAVE, valid_text, fmt='%i', delimiter=' ')
        print('saved!\n\n\n')
        # sys.exit, not the site-provided exit() meant for interactive use.
        sys.exit()
def run_model(opt, train, val, test, wordtoix, ixtoword):
    """Train the denoising auto-encoder, optionally from a saved embedding.

    Tries to load a pre-trained embedding matrix from ``./param_g.npz``
    (key ``'Wemb'``) into ``opt.W_emb``.  If the file is missing, or its
    shape differs from ``(opt.n_words, opt.embed_size)``, ``opt.fix_emb``
    is set to False instead.

    It then builds the ``auto_encoder`` graph, optionally restores the
    trainable variables with ``restore_from_save``, and trains for
    ``opt.max_epochs`` epochs on shuffled minibatches.  Inputs are
    ``add_noise``-corrupted sentences and targets are the original
    sentences.  Every ``opt.valid_freq`` iterations one random validation
    batch is evaluated (loss, a sample reconstruction and BLEU-2/3/4).
    Every ``opt.print_freq`` iterations a training sample is printed.
    Summaries go under ``opt.log_path + '/train'`` and
    ``opt.log_path + '/test'``.  After each epoch a checkpoint with
    ``global_step=epoch`` is saved to ``opt.save_path``.

    Args:
        opt: options object.  Fields read here include ``n_words``,
            ``embed_size``, ``batch_size``, ``sent_len``, ``restore``,
            ``max_epochs``, ``model``, ``valid_freq``, ``print_freq``,
            ``discrimination``, ``log_path`` and ``save_path``.
        train: indexable collection of training sentences (token indices).
        val: indexable collection of validation sentences (token indices).
        test: not used by this function.
        wordtoix: maps a token to its index; used only by ``test_input``.
        ixtoword: maps a token index to its string; used for printing.

    Note:
        Reads the module-level flag ``profile``.  When it is true, training
        steps run with full tracing and a tfprof timing/memory report is
        printed.
    """
    try:
        # Close the .npz archive as soon as the matrix has been read.
        with np.load('./param_g.npz') as params:
            w_emb = params['Wemb']
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False
    else:
        if w_emb.shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = w_emb
        else:
            print('Emb Dimension mismatch: param_g.npz:' + str(w_emb.shape) +
                  ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()

    def _prepare_batch(sents, sents_noised):
        """Return ``(model_input, target)`` arrays for one batch."""
        # 'rnn_rnn' and 'cnn_rnn' models take RNN-prepared targets.
        if opt.model in ('rnn_rnn', 'cnn_rnn'):
            target = prepare_data_for_rnn(sents, opt)
        else:
            target = prepare_data_for_cnn(sents, opt)
        # Only 'rnn_rnn' takes RNN-prepared input, and without a GO token.
        if opt.model == 'rnn_rnn':
            model_input = prepare_data_for_rnn(sents_noised, opt,
                                               is_add_GO=False)
        else:
            model_input = prepare_data_for_cnn(sents_noised, opt)
        return model_input, target

    def _decode(indices):
        """Render token indices as space-separated text, skipping index 0."""
        return " ".join([ixtoword[x] for x in indices if x != 0])

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                restore_from_save(tf.trainable_variables(), sess, opt)
            except Exception as e:
                # Best effort: fall back to a fresh random initialization.
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        def test_input(text):
            """Print the model's reconstruction of free-form ``text``.

            Not called within this function.  Presumably it is a debugging
            aid for interactive use (e.g. from a debugger stopped inside
            the training loop).  The target feed reuses ``x_batch_org``
            from the most recent training batch, so it must be called after
            at least one iteration.
            """
            x_input = sent2idx(text, wordtoix, opt)
            res = sess.run(res_, feed_dict={x_: x_input,
                                            x_org_: x_batch_org,
                                            is_train_: False})
            print("Reconstructed:" + _decode(res['rec_sents'][0]))

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                x_batch, x_batch_org = _prepare_batch(sents, sents_permutated)
                train_feed = {x_: x_batch, x_org_: x_batch_org,
                              is_train_: True}

                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict=train_feed,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_], feed_dict=train_feed)

                if uidx % opt.valid_freq == 0:
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    x_val_batch, x_val_batch_org = _prepare_batch(
                        val_sents, val_sents_permutated)
                    # Evaluate with the is_train_ placeholder explicitly False.
                    val_feed = {x_: x_val_batch, x_org_: x_val_batch_org,
                                is_train_: False}

                    loss_val = sess.run(loss_, feed_dict=val_feed)
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_, feed_dict=val_feed)
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    print("Val Orig :" + _decode(val_sents[0]))
                    print("Val Recon:" + _decode(res['rec_sents'][0]))

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    bleu2s, bleu3s, bleu4s = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

                    summary = sess.run(merged, feed_dict=val_feed)
                    test_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict=train_feed)
                    print("Original     :" + _decode(sents[0]))
                    if opt.model in ('rnn_rnn', 'cnn_rnn'):
                        print("Reconstructed:" +
                              _decode(res['rec_sents_feed_y'][0]))
                    print("Reconstructed:" + _decode(res['rec_sents'][0]))

                    summary = sess.run(merged, feed_dict=train_feed)
                    train_writer.add_summary(summary, uidx)
                    if profile:
                        tf.contrib.tfprof.model_analyzer.print_model_analysis(
                            tf.get_default_graph(),
                            run_meta=run_metadata,
                            tfprof_options=tf.contrib.tfprof.model_analyzer.
                            PRINT_ALL_TIMING_MEMORY)

            saver.save(sess, opt.save_path, global_step=epoch)