def translate(trainer):
    # Sample a random validation batch, corrupt it with noise, and print the
    # source, the generated reconstruction, and the soft-argmax decoding.
    valid_index = np.random.choice(len(test_data), args.batchsize)
    val_sents = [test_data[t] for t in valid_index]
    val_sents_permutated = denoise.add_noise(val_sents, opt)
    x_val_batch = utils.prepare_data_for_cnn(val_sents_permutated, opt.maxlen,
                                             opt.filter_shape)
    x_val_batch_org = utils.prepare_data_for_rnn(val_sents, opt.maxlen,
                                                 opt.sent_len, opt.n_words,
                                                 is_add_GO=True)
    mdl = generator
    xp = mdl.xp
    x_val_batch = xp.array(x_val_batch)
    x_val_batch_org = xp.array(x_val_batch_org)
    syn_sents, logits = mdl(x_val_batch, x_val_batch_org)
    # Soft-argmax: sharpen the softmax with temperature opt.L.
    prob = [F.softmax(l * opt.L) for l in logits]
    prob = F.stack(prob, 1)
    source_sentence = ' '.join(
        [source_words[int(i)] for i in x_val_batch_org[0] if i != PAD])
    result_sentence = ' '.join([
        source_words.get(int(i), '*NOKEY') for i in syn_sents.data[0]
        if i != PAD
    ])
    prob_sentence = ' '.join(
        [source_words[int(xp.argmax(p))] for p in prob.data[:, 0]])
    print('# source : ' + source_sentence)
    print('# result : ' + result_sentence)
    print('# prob   : ' + prob_sentence)
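# NOTE: `F.softmax(l * opt.L)` above is the soft-argmax trick: scaling the
# logits by a large temperature opt.L before the softmax pushes each
# distribution toward one-hot while keeping the pipeline differentiable.
# A minimal NumPy illustration (the logit values below are made up):
def _soft_argmax_demo():
    import numpy as np
    logits = np.array([1.0, 2.0, 0.5])
    for L in (1.0, 100.0):
        scaled = L * logits
        p = np.exp(scaled - scaled.max())  # numerically stable softmax
        p /= p.sum()
        print(L, p.round(3))  # L=1 -> smooth; L=100 -> ~one-hot at index 1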
def run_model(opt, X):
    # Load a pre-trained word embedding if its shape matches the vocabulary.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        if opt.plot_type == 'ae':
            x_lat_ = ae(x_, opt)
        elif opt.plot_type == 'vae' or opt.plot_type == 'cyc':
            mu_, z_ = vae(x_, opt)
            x_lat_ = z_ if opt.use_z else mu_

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.3
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        # Encode the whole corpus batch-by-batch into the latent space.
        X_emb = np.zeros([len(X), opt.z_dim], dtype='float32')
        kf = get_minibatches_idx(len(X), opt.batch_size)
        t = 0
        for _, index in kf:
            sents_b = [X[i] for i in index]
            x_b = prepare_data_for_cnn(sents_b, opt)
            x_lat = np.squeeze(sess.run(x_lat_, feed_dict={x_: x_b}))
            X_emb[t * opt.batch_size:(t + 1) * opt.batch_size] = x_lat
            if (t + 1) % 10 == 0:
                print('%d / %d' % (t + 1, len(kf)))
            t += 1
    return X_emb
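# NOTE: get_minibatches_idx is imported from the project's utilities and is
# not shown in this section. Judging from how it is consumed above
# (`for _, index in kf` and `len(kf)`), it behaves like the classic Theano
# LSTM-tutorial helper of the same name; the following is a sketch, not the
# project's implementation:
def get_minibatches_idx_sketch(n, batch_size, shuffle=False):
    # Return a list of (batch number, index array) pairs covering all n
    # examples; the final batch may be smaller than batch_size.
    idx = np.arange(n, dtype='int64')
    if shuffle:
        np.random.shuffle(idx)
    batches = [idx[i:i + batch_size] for i in range(0, n, batch_size)]
    return list(enumerate(batches))
# Caveat: run_model above writes each batch into
# X_emb[t * opt.batch_size:(t + 1) * opt.batch_size], which assumes full
# batches; a trailing partial batch would need special handling.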
def convert(batch, device):
    # Move a noisy CNN input and the clean RNN target to the right device.
    if device < 0:
        xp = np
    else:
        xp = cuda.cupy
    x = [x_ for x_ in batch]
    x = denoise.add_noise(x, opt)
    x_org = [x_ for x_ in batch]
    x = utils.prepare_data_for_cnn(x, opt.maxlen, opt.filter_shape)
    x_org = utils.prepare_data_for_rnn(x_org, opt.maxlen, opt.sent_len,
                                       opt.n_words, is_add_GO=True)
    x = xp.array(x, dtype=np.int32)
    x_org = xp.array(x_org, dtype=np.int32)
    return {'x': x, 'x_org': x_org}
def main():
    import utils
    maxlen = 51
    filter_shape = 5
    sent_len = maxlen + 2 * (filter_shape - 1)
    n_words = 5728
    # m = auto_encoder(n_words, maxlen=maxlen)
    m = textGan_generator(n_words, maxlen)
    d = textGan_discriminator(m.embedding, n_words, maxlen=maxlen)
    # m = textGan(n_words, maxlen=maxlen)
    data = np.arange(20 * 2, dtype=np.int32).reshape(2, 20)
    x = utils.prepare_data_for_cnn(data, maxlen, filter_shape)
    x_orig = utils.prepare_data_for_rnn(data, maxlen, sent_len, n_words)
    syn_sents, prob = m(x, x_orig)
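# NOTE: prepare_data_for_cnn / prepare_data_for_rnn are not defined in this
# section. The sketch below is an assumption reconstructed from how they are
# called here: the CNN input is padded on both sides by filter_shape - 1
# positions (hence sent_len = maxlen + 2 * (filter_shape - 1)), the RNN input
# is optionally prefixed with the GO token, and PAD id 0 / GO id n_words - 1
# follow the conventions visible elsewhere in this file
# (ixtoword[opt.n_words - 1] = 'GO_'; words are printed only `if x != 0`).
def prepare_data_for_cnn_sketch(sents, maxlen, filter_shape, pad_id=0):
    side = filter_shape - 1
    sent_len = maxlen + 2 * side
    out = np.full((len(sents), sent_len), pad_id, dtype='int32')
    for i, s in enumerate(sents):
        s = list(s)[:maxlen]
        out[i, side:side + len(s)] = s  # center the tokens between pads
    return out

def prepare_data_for_rnn_sketch(sents, maxlen, sent_len, n_words,
                                is_add_GO=True, pad_id=0):
    go_id = n_words - 1  # 'GO_' is the last vocabulary entry
    out = np.full((len(sents), sent_len), pad_id, dtype='int32')
    for i, s in enumerate(sents):
        s = ([go_id] if is_add_GO else []) + list(s)[:maxlen]
        out[i, :len(s)] = s
    return out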
def update_core(self):
    gen_optimizer = self.get_optimizer('opt_gen')
    dis_optimizer = self.get_optimizer('opt_dis')
    xp = self.gen.xp
    opt = self.opt

    batch = self.get_iterator('main').next()
    batchsize = len(batch)
    x = denoise.add_noise(batch, self.opt)
    x = utils.prepare_data_for_cnn(x, opt.maxlen, opt.filter_shape)
    x_org = utils.prepare_data_for_rnn(batch, opt.maxlen, opt.sent_len,
                                       opt.n_words, is_add_GO=True)
    x = xp.array(x, dtype=np.int32)
    x_org = xp.array(x_org, dtype=np.int32)

    # Generator: prob is the soft (probability) fake data.
    syn_sents, prob = self.gen(x, x_org)

    # Discriminator on real indices and on the soft fake input.
    logits_real, H_real = self.dis(x)
    logits_fake, H_fake = self.dis(prob, is_prob=True)

    # Label index arrays (1-dim): 1 for real, 0 for fake.
    labels_one = xp.ones((batchsize,), dtype=xp.int32)
    labels_zero = xp.zeros((batchsize,), dtype=xp.int32)
    labels_fake = labels_zero  # F.concat([labels_one, labels_zero], axis=1)
    labels_real = labels_one   # F.concat([labels_zero, labels_one], axis=1)

    D_loss = F.softmax_cross_entropy(logits_real, labels_real) + \
             F.softmax_cross_entropy(logits_fake, labels_fake)
    G_loss = compute_MMD_loss(F.squeeze(H_fake), F.squeeze(H_real))

    self.gen.cleargrads()
    G_loss.backward()
    gen_optimizer.update()

    self.dis.cleargrads()
    D_loss.backward()
    dis_optimizer.update()

    H_fake.unchain_backward()
    H_real.unchain_backward()
    prob.unchain_backward()

    chainer.reporter.report({'loss_gen': G_loss})
    chainer.reporter.report({'loss_dis': D_loss})
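# NOTE: compute_MMD_loss is defined elsewhere in the project. For reference,
# a squared maximum mean discrepancy between the fake and real discriminator
# features, with a mixture of Gaussian kernels, could be sketched in Chainer
# as below (chainer.functions is imported as F at module level); the
# bandwidth set and kernel mixture are assumptions, not the project's code.
def compute_MMD_loss_sketch(H_fake, H_real, sigmas=(0.1, 1.0, 10.0)):
    def kernel_mean(A, B):
        # Pairwise squared distances via ||a||^2 - 2 a.b + ||b||^2.
        sq_a = F.sum(A * A, axis=1, keepdims=True)              # (n, 1)
        sq_b = F.sum(B * B, axis=1, keepdims=True)              # (m, 1)
        d2 = sq_a - 2 * F.matmul(A, B, transb=True) + F.transpose(sq_b)
        k = 0
        for s in sigmas:  # mixture of RBF kernels
            k = k + F.exp(-d2 / (2 * s))
        return F.average(k)
    # MMD^2 = E[k(f,f')] + E[k(r,r')] - 2 E[k(f,r)]
    return (kernel_mean(H_fake, H_fake) + kernel_mean(H_real, H_real)
            - 2 * kernel_mean(H_fake, H_real))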
def run_epoch(sess, epoch, mode, print_freq=-1, train_writer=None):
    fetches_ = {'loss': loss_, 'accuracy': accuracy_}
    if mode == 'train':
        x, y, is_train = train, train_lab, 1
        fetches_['train_op'] = train_op_
        fetches_['summary'] = merged
    elif mode == 'val':
        assert print_freq == -1
        x, y, is_train = val, val_lab, None
    elif mode == 'test':
        assert print_freq == -1
        x, y, is_train = test, test_lab, None

    correct, acc_loss, acc_n = 0.0, 0.0, 0.0
    local_t = 0
    global_t = epoch * epoch_t  # only used in train mode
    start_time = time.time()
    kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)
    for _, index in kf:
        local_t += 1
        global_t += 1
        sents_b = [x[i] for i in index]
        sents_b_n = add_noise(sents_b, opt)
        y_b = np.array([y[i] for i in index]).reshape((len(index), 1))
        x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch x L
        feed_t = {x_: x_b, y_: y_b, is_train_: is_train}
        fetches = sess.run(fetches_, feed_dict=feed_t)
        batch_size = len(index)
        acc_n += batch_size
        acc_loss += fetches['loss'] * batch_size
        correct += fetches['accuracy'] * batch_size
        if print_freq > 0 and local_t % print_freq == 0:
            print("%s Iter %d: loss %.4f, acc %.4f, time %.1fs" %
                  (mode, local_t, acc_loss / acc_n, correct / acc_n,
                   time.time() - start_time))
            if mode == 'train' and train_writer is not None:
                train_writer.add_summary(fetches['summary'], global_t)
    print("%s Epoch %d: loss %.4f, acc %.4f, time %.1fs" %
          (mode, epoch, acc_loss / acc_n, correct / acc_n,
           time.time() - start_time))
    return acc_loss / acc_n, correct / acc_n
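# NOTE: run_epoch relies on module-level names (train, val, test, loss_,
# train_op_, merged, opt, ...), so it is presumably driven by a loop in the
# same script. A minimal sketch of such a driver; the early-stopping rule and
# the variable names below are assumptions:
def train_loop_sketch(sess, saver, train_writer):
    best_val_acc = 0.0
    for epoch in range(opt.max_epochs):
        run_epoch(sess, epoch, 'train', print_freq=opt.print_freq,
                  train_writer=train_writer)
        _, val_acc = run_epoch(sess, epoch, 'val')
        if val_acc > best_val_acc:  # keep the best-on-validation model
            best_val_acc = val_acc
            run_epoch(sess, epoch, 'test')
            saver.save(sess, opt.save_path, global_step=epoch)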
def main():
    # global n_words
    # Prepare training and testing data.
    # loadpath = "./data/three_corpus_small.p"
    loadpath = "./data/hotel_reviews.p"
    x = cPickle.load(open(loadpath, "rb"))
    train, val = x[0], x[1]
    wordtoix, ixtoword = x[2], x[3]

    opt = Options()
    opt.n_words = len(ixtoword) + 1
    ixtoword[opt.n_words - 1] = 'GO_'
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, loss_, train_op = auto_encoder(x_, x_org_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    # writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())
    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # config = tf.ConfigProto(device_count={'GPU': 0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                # pdb.set_trace()
                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                if opt.substitution == 's':
                    sents_permutated = substitute_sent(sents, opt)
                elif opt.substitution == 'p':
                    sents_permutated = permutate_sent(sents, opt)
                elif opt.substitution == 'a':
                    sents_permutated = add_sent(sents, opt)
                elif opt.substitution == 'd':
                    sents_permutated = delete_sent(sents, opt)
                else:
                    sents_permutated = sents
                # sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch x L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch x L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt,
                                                   is_add_GO=False)

                # x_print = sess.run([x_emb], feed_dict={x_: x_train})
                # print x_print
                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org})
                # pdb.set_trace()
                _, loss = sess.run([train_op, loss_],
                                   feed_dict={x_: x_batch,
                                              x_org_: x_batch_org})

                if uidx % opt.valid_freq == 0:
                    opt.is_train = False
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    if opt.substitution == 's':
                        val_sents_permutated = substitute_sent(val_sents, opt)
                    elif opt.substitution == 'p':
                        val_sents_permutated = permutate_sent(val_sents, opt)
                    elif opt.substitution == 'a':
                        val_sents_permutated = add_sent(val_sents, opt)
                    elif opt.substitution == 'd':
                        val_sents_permutated = delete_sent(val_sents, opt)
                    else:
                        val_sents_permutated = val_sents

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={x_: x_val_batch,
                                                   x_org_: x_val_batch_org})
                    print("Validation loss %f " % loss_val)
                    res = sess.run(res_,
                                   feed_dict={x_: x_val_batch,
                                              x_org_: x_val_batch_org})
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    print "Val Orig :" + " ".join(
                        [ixtoword[x] for x in val_sents[0] if x != 0])
                    print "Val Perm :" + " ".join(
                        [ixtoword[x] for x in val_sents_permutated[0]
                         if x != 0])
                    print "Val Recon:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu2s, bleu3s, bleu4s)])

                    summary = sess.run(merged,
                                       feed_dict={x_: x_val_batch,
                                                  x_org_: x_val_batch_org})
                    test_writer.add_summary(summary, uidx)
                    opt.is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={x_: x_batch,
                                              x_org_: x_batch_org})
                    print "Original :" + " ".join(
                        [ixtoword[x] for x in sents[0] if x != 0])
                    print "Permutated :" + " ".join(
                        [ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        print "Reconstructed:" + " ".join(
                            [ixtoword[x] for x in res['rec_sents_feed_y'][0]
                             if x != 0])
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    summary = sess.run(merged,
                                       feed_dict={x_: x_batch,
                                                  x_org_: x_batch_org})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            saver.save(sess, opt.save_path, global_step=epoch)
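# NOTE: add_noise, used by the other training scripts in this section, is not
# defined here, but the opt.substitution switch in main() above suggests it
# simply wraps the same dispatch into one helper. A sketch, not the project's
# code; the per-branch comments are inferred from the helper names:
def add_noise_sketch(sents, opt):
    if opt.substitution == 's':
        return substitute_sent(sents, opt)   # replace random words
    elif opt.substitution == 'p':
        return permutate_sent(sents, opt)    # shuffle word order
    elif opt.substitution == 'a':
        return add_sent(sents, opt)          # insert random words
    elif opt.substitution == 'd':
        return delete_sent(sents, opt)       # drop random words
    return sents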
def run_model(opt, train, val, test, wordtoix, ixtoword):
    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        summary_ext = tf.Summary()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch x L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch x L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt,
                                                   is_add_GO=False)

                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org})
                # pdb.set_trace()
                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={x_: x_batch, x_org_: x_batch_org,
                                   is_train_: 1},
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch,
                                                  x_org_: x_batch_org,
                                                  is_train_: 1})
                # pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={x_: x_val_batch,
                                                   x_org_: x_val_batch_org,
                                                   is_train_: is_train})
                    print("Validation loss %f " % loss_val)
                    res = sess.run(res_,
                                   feed_dict={x_: x_val_batch,
                                              x_org_: x_val_batch_org,
                                              is_train_: is_train})
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    if opt.char:
                        print "Val Orig :" + "".join(
                            [ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + "".join(
                            [ixtoword[x] for x in val_sents_permutated[0]
                             if x != 0])
                        print "Val Recon:" + "".join(
                            [ixtoword[x] for x in res['rec_sents'][0]
                             if x != 0])
                        # print "Val Recon one hot:" + "".join([ixtoword[x] for x in res['rec_sents_one_hot'][0] if x != 0])
                    else:
                        print "Val Orig :" + " ".join(
                            [ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + " ".join(
                            [ixtoword[x] for x in val_sents_permutated[0]
                             if x != 0])
                        print "Val Recon:" + " ".join(
                            [ixtoword[x] for x in res['rec_sents'][0]
                             if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu2s, bleu3s, bleu4s)])

                    val_set_char = [prepare_for_cer(s, ixtoword)
                                    for s in val_sents]
                    cer = cal_cer([prepare_for_cer(s, ixtoword)
                                   for s in res['rec_sents']], val_set_char)
                    print 'Val CER: ' + str(round(cer, 3))
                    # summary_ext.Value(tag='CER', simple_value=cer)
                    summary_ext = tf.Summary(value=[
                        tf.Summary.Value(tag='CER', simple_value=cer)])
                    # tf.summary.scalar('CER', cer)

                    # if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #     print "Gen Probs:" + " ".join([str(np.round(res['gen_p'][i], 1)) for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged,
                                       feed_dict={x_: x_val_batch,
                                                  x_org_: x_val_batch_org,
                                                  is_train_: is_train})
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary_ext, uidx)
                    is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={x_: x_batch,
                                              x_org_: x_batch_org,
                                              is_train_: 1})
                    # if 1 in res['rec_sents'][0] or 1 in sents[0]:
                    #     pdb.set_trace()
                    if opt.char:
                        print "Original :" + "".join(
                            [ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated :" + "".join(
                            [ixtoword[x] for x in sents_permutated[0]
                             if x != 0])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + "".join(
                                [ixtoword[x]
                                 for x in res['rec_sents_feed_y'][0]
                                 if x != 0])
                        print "Reconstructed:" + "".join(
                            [ixtoword[x] for x in res['rec_sents'][0]
                             if x != 0])
                    else:
                        print "Original :" + " ".join(
                            [ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated :" + " ".join(
                            [ixtoword[x] for x in sents_permutated[0]
                             if x != 0])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + " ".join(
                                [ixtoword[x]
                                 for x in res['rec_sents_feed_y'][0]
                                 if x != 0])
                        print "Reconstructed:" + " ".join(
                            [ixtoword[x] for x in res['rec_sents'][0]
                             if x != 0])

                    summary = sess.run(merged,
                                       feed_dict={x_: x_batch,
                                                  x_org_: x_batch_org,
                                                  is_train_: 1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

        if profile:
            tf.contrib.tfprof.model_analyzer.print_model_analysis(
                tf.get_default_graph(), run_meta=run_metadata,
                tfprof_options=tf.contrib.tfprof.model_analyzer.
                PRINT_ALL_TIMING_MEMORY)
        saver.save(sess, opt.save_path)
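# NOTE: cal_cer / prepare_for_cer come from the project's utilities.
# Character error rate is conventionally the Levenshtein (edit) distance
# between hypothesis and reference character sequences, normalized by the
# reference length; a self-contained sketch (the function name is an
# assumption):
def char_error_rate_sketch(hyp, ref):
    # d[i][j] = edit distance between hyp[:i] and ref[:j].
    d = [[0] * (len(ref) + 1) for _ in range(len(hyp) + 1)]
    for i in range(len(hyp) + 1):
        d[i][0] = i
    for j in range(len(ref) + 1):
        d[0][j] = j
    for i in range(1, len(hyp) + 1):
        for j in range(1, len(ref) + 1):
            sub = 0 if hyp[i - 1] == ref[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + sub)  # substitution
    return float(d[len(hyp)][len(ref)]) / max(len(ref), 1)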
def main():
    # global n_words
    # Prepare training and testing data.
    opt = COptions(args)
    opt_t = COptions(args)
    loadpath = opt.data_dir + "/" + opt.data_name
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    if opt.test:
        test_file = opt.data_dir + "/newdata2/test.txt"
        test = read_test(test_file, wordtoix)
        test = [x for x in test
                if all([2 < len(x[t]) < opt.maxlen - 4
                        for t in range(opt.num_turn)])]
    train_filtered = [x for x in train
                      if all([2 < len(x[t]) < opt.maxlen - 4
                              for t in range(opt.num_turn)])]
    val_filtered = [x for x in val
                    if all([2 < len(x[t]) < opt.maxlen - 4
                            for t in range(opt.num_turn)])]
    print("Train: %d => %d" % (len(train), len(train_filtered)))
    print("Val: %d => %d" % (len(val), len(val_filtered)))
    train, val = train_filtered, val_filtered
    del train_filtered, val_filtered

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [tf.placeholder(tf.int32,
                                   shape=[opt.batch_size, opt.sent_len])
                    for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32,
                                  shape=[opt_t.batch_size, opt_t.sent_len])
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_, gan_cost_g_, train_op_g = conditional_s2s(
                src_, tgt_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                if opt.load_from_pretrain:
                    d_vars = [var for var in t_vars
                              if var.name.startswith('d_')]
                    g_vars = [var for var in t_vars
                              if var.name.startswith('g_')]
                    l_vars = [var for var in t_vars
                              if var.name.startswith('l_')]
                    restore_from_save(d_vars, sess, opt,
                                      load_path=opt.restore_dir + "/save/" +
                                      opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars, sess, opt,
                                          load_path=opt.restore_dir +
                                          "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars, sess, opt,
                                               load_path=opt.save_path)
            except Exception as e:
                print 'Error: ' + str(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        loss_d, loss_g = 0, 0

        if opt.test:
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            res_all = []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,
                                   (i + 1) * opt.batch_size)
                sents = [test[t % len(test)] for t in test_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[sents[i][idx - turn]
                            for i in range(opt.batch_size)]
                           for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    x_batch = [prepare_data_for_cnn(src_i, opt)
                               for src_i in src]  # Batch x L
                    y_batch = prepare_data_for_rnn(tgt, opt_t,
                                                   is_add_GO=False)
                    feed = merge_two_dicts(
                        {i: d for i, d in zip(src_, x_batch)},
                        {tgt_: y_batch, is_train_: 0})  # do not use False
                    res = sess.run(res_, feed_dict=feed)
                    res_all.extend(res['syn_sent'])
            # bp()
            res_all = reshaping(res_all, opt)
            for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                with open(opt.log_path + '.resp.txt', "a") as resp_f:
                    resp_f.write(u' '.join(
                        [ixtoword[x] for x in res_all[idx]
                         if x != 0 and x != 2]).encode('utf-8').strip() +
                        ('\n' if idx % (opt.num_turn - opt.n_context) == 0
                         else '\t'))
            print("save to:" + opt.log_path + '.resp.txt')
            exit(0)

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[sents[i][idx - turn]
                            for i in range(opt.batch_size)]
                           for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    x_batch = [prepare_data_for_cnn(src_i, opt)
                               for src_i in src]  # Batch x L
                    y_batch = prepare_data_for_rnn(tgt, opt_t,
                                                   is_add_GO=False)
                    feed = merge_two_dicts(
                        {i: d for i, d in zip(src_, x_batch)},
                        {tgt_: y_batch, is_train_: 1})
                    _, loss_g = sess.run([train_op_g, gan_cost_g_],
                                         feed_dict=feed)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss G %f" % (uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    if opt.global_feature:
                        print "z loss: " + str(res['z_loss'])
                        if "nn" in opt.agg_model:
                            print "z pred_loss: " + str(res['z_loss_pred'])
                    print "Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch for x in s[0]
                         if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join(
                        [ixtoword[x] for x in y_batch[0]
                         if x != 0]).encode('utf-8').strip()
                    print "Generated:" + u' '.join(
                        [ixtoword[x] for x in res['syn_sent'][0]
                         if x != 0]).encode('utf-8').strip()
                    print ""
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(
                        np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_g_all = [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        sents = [val[t] for t in valid_index]
                        for idx in range(opt.n_context, opt.num_turn):
                            src = [[sents[i][idx - turn]
                                    for i in range(opt.batch_size)]
                                   for turn in range(opt.n_context, 0, -1)]
                            tgt = [sents[i][idx]
                                   for i in range(opt.batch_size)]
                            val_tgt_all.extend(tgt)
                            x_batch = [prepare_data_for_cnn(src_i, opt)
                                       for src_i in src]  # Batch x L
                            y_batch = prepare_data_for_rnn(tgt, opt_t,
                                                           is_add_GO=False)
                            feed = merge_two_dicts(
                                {i: d for i, d in zip(src_, x_batch)},
                                {tgt_: y_batch, is_train_: 0})  # do not use False
                            loss_val_g = sess.run([gan_cost_g_],
                                                  feed_dict=feed)
                            loss_val_g_all.append(loss_val_g)
                            res = sess.run(res_, feed_dict=feed)
                            res_all.extend(res['syn_sent'])
                            if opt.global_feature:
                                z_loss_all.append(res['z_loss'])

                    print("Validation: loss G %f " %
                          np.mean(loss_val_g_all))
                    if opt.global_feature:
                        print "z loss: " + str(np.mean(z_loss_all))
                    print "Val Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch for x in s[0]
                         if x != 0]).encode('utf-8').strip()
                    print "Val Target:" + u' '.join(
                        [ixtoword[x] for x in y_batch[0]
                         if x != 0]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join(
                        [ixtoword[x] for x in res['syn_sent'][0]
                         if x != 0]).encode('utf-8').strip()
                    print ""

                    if opt.global_feature:
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" %
                                         np.mean(z_loss_all) + "\n")
                            myfile.write("Val Source:" + u' '.join(
                                [ixtoword[x] for s in x_batch for x in s[0]
                                 if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Target:" + u' '.join(
                                [ixtoword[x] for x in y_batch[0]
                                 if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Generated:" + u' '.join(
                                [ixtoword[x] for x in res['syn_sent'][0]
                                 if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Z_input, Z_recon, Z_tgt")
                            myfile.write(np.array2string(
                                res['z'][0],
                                formatter={'float_kind':
                                           lambda x: "%.2f" % x}) + "\n")
                            myfile.write(np.array2string(
                                res['z_hat'][0],
                                formatter={'float_kind':
                                           lambda x: "%.2f" % x}) + "\n\n")
                            myfile.write(np.array2string(
                                res['z_tgt'][0],
                                formatter={'float_kind':
                                           lambda x: "%.2f" % x}) + "\n\n")

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s, bleu4s] = cal_BLEU_4(
                        gen, {0: val_set}, is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    print 'Val BLEU: ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu1s, bleu2s, bleu3s, bleu4s)])
                    print 'Val Entropy: ' + ' '.join(
                        [str(round(it, 3))
                         for it in (etp_score[0], etp_score[1],
                                    etp_score[2], etp_score[3])])
                    print 'Val Diversity: ' + ' '.join(
                        [str(round(it, 3))
                         for it in (dist_score[0], dist_score[1],
                                    dist_score[2], dist_score[3])])
                    print 'Val Avg. length: ' + str(round(
                        np.mean([len([y for y in x if y != 0])
                                 for x in res_all]), 3))
                    print ""
                    summary = sess.run(merged, feed_dict=feed)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="etp-4",
                                         simple_value=etp_score[3])])
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
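# NOTE: merge_two_dicts, used to build the combined feed_dict above, is the
# standard Python 2 idiom (dict unpacking only arrived in Python 3.5); it is
# not shown in this section but is conventionally implemented as:
def merge_two_dicts(a, b):
    merged = a.copy()   # start from a's entries
    merged.update(b)    # b's entries win on key collisions
    return merged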
def run_model(opt, train, val, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    # writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())
    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    # config = tf.ConfigProto(device_count={'GPU': 0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                # pdb.set_trace()
                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("\nStarting epoch %d\n" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                print "\rIter: %d" % uidx,
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                # sents[0] = np.random.permutation(sents[0])

                x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch x L
                if x_batch.shape[0] == opt.batch_size:
                    d_loss = 0
                    g_loss = 0
                    if profile:
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run(
                                [dis_op, d_loss_], feed_dict={x_: x_batch},
                                options=tf.RunOptions(
                                    trace_level=tf.RunOptions.FULL_TRACE),
                                run_metadata=run_metadata)
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run(
                                [gen_op, g_loss_], feed_dict={x_: x_batch},
                                options=tf.RunOptions(
                                    trace_level=tf.RunOptions.FULL_TRACE),
                                run_metadata=run_metadata)
                    else:
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run([dis_op, d_loss_],
                                                 feed_dict={x_: x_batch})
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run([gen_op, g_loss_],
                                                 feed_dict={x_: x_batch})

                if uidx % opt.valid_freq == 0:
                    is_train = True
                    # print('Valid Size:', len(val))
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    x_val_batch = prepare_data_for_cnn(val_sents_permutated,
                                                       opt)

                    d_loss_val = sess.run(d_loss_,
                                          feed_dict={x_: x_val_batch})
                    g_loss_val = sess.run(g_loss_,
                                          feed_dict={x_: x_val_batch})
                    res = sess.run(res_, feed_dict={x_: x_val_batch})
                    print("Validation d_loss %f, g_loss %f mean_dist %f" %
                          (d_loss_val, g_loss_val, res['mean_dist']))
                    print "Sent:" + u' '.join(
                        [ixtoword[x] for x in res['syn_sent'][0]
                         if x != 0]).encode('utf-8').strip()
                    print("MMD loss %f, GAN loss %f" %
                          (res['mmd'], res['gan']))
                    np.savetxt('./text/rec_val_words.txt', res['syn_sent'],
                               fmt='%i', delimiter=' ')
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['syn_sent']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu2s, bleu3s, bleu4s)])

                    summary = sess.run(merged, feed_dict={x_: x_val_batch})
                    test_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    # pdb.set_trace()
                    res = sess.run(res_, feed_dict={x_: x_batch})
                    median_dis = np.sqrt(np.median(
                        [((x - y) ** 2).sum()
                         for x in res['real_f'] for y in res['real_f']]))
                    print("Iteration %d: d_loss %f, g_loss %f, mean_dist %f, "
                          "realdist median %f" %
                          (uidx, d_loss, g_loss, res['mean_dist'],
                           median_dis))
                    np.savetxt('./text/rec_train_words.txt', res['syn_sent'],
                               fmt='%i', delimiter=' ')
                    print "Sent:" + u' '.join(
                        [ixtoword[x] for x in res['syn_sent'][0]
                         if x != 0]).encode('utf-8').strip()
                    summary = sess.run(merged, feed_dict={x_: x_batch})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            if profile:
                tf.contrib.tfprof.model_analyzer.print_model_analysis(
                    tf.get_default_graph(), run_meta=run_metadata,
                    tfprof_options=tf.contrib.tfprof.model_analyzer.
                    PRINT_ALL_TIMING_MEMORY)
            saver.save(sess, opt.save_path, global_step=epoch)
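# NOTE: the median_dis diagnostic printed above is the median of all pairwise
# distances among real features, i.e. the classic median heuristic for
# choosing an RBF kernel bandwidth in MMD-based losses. The double list
# comprehension is an O(n^2) Python loop; a vectorized NumPy equivalent of
# the same quantity (a sketch):
def median_pairwise_distance_sketch(feats):
    sq = np.sum(feats ** 2, axis=1)
    d2 = sq[:, None] - 2 * feats.dot(feats.T) + sq[None, :]
    return np.sqrt(np.median(np.maximum(d2, 0)))  # clip tiny negatives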
def main(opt):
    # global n_words
    # Prepare training and testing data.
    data_path = opt.data_dir + "/" + opt.data_name
    print('loading ' + data_path)
    x = cPickle.load(open(data_path, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    opt.n_words = len(ixtoword)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    with tf.device('/gpu:1'):
        x_1_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_2_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        y_ = tf.placeholder(tf.float32, shape=[opt.batch_size, ])
        l_temp_ = tf.placeholder(tf.float32, shape=[])
        res_, loss_, train_op = cons_disc(x_1_, x_2_, y_, opt, l_temp_)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        # feed_dict = {x_: np.zeros([opt.batch_size, opt.sent_len]),
        #              x_org_: np.zeros([opt.batch_size, opt.sent_len])}
        if opt.restore:
            print('-' * 20)
            print("Loading variables from '%s'." % opt.load_path)
            try:
                # pdb.set_trace()
                t_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                # print([var.name[:-2] for var in t_vars])
                save_keys = tensors_key_in_file(opt.load_path)
                # Keep only variables whose name and shape match the checkpoint.
                ss = [var for var in t_vars
                      if var.name[:-2] in save_keys.keys()]
                ss = [var.name for var in ss
                      if var.get_shape() == save_keys[var.name[:-2]]]
                loader = tf.train.Saver(
                    var_list=[var for var in t_vars if var.name in ss])
                loader.restore(sess, opt.load_path)
                print("Loaded variables:" + str(ss))
                print('-' * 20)
            except Exception as e:
                print 'Error: ' + str(e)
                exit()

        # Train; to skip training entirely, set max_epochs=0.
        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            opt.l_temp = min(opt.l_temp * opt.l_temp_factor, opt.l_temp_max)
            print("Annealing temperature " + str(opt.l_temp))
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                indice = [rand_pair(opt.task, opt.data_name)
                          for _ in range(opt.batch_size)]
                if opt.task == 'L':
                    x_1 = [sents[i][idx[0]] for i, idx in enumerate(indice)]
                    x_2 = [sents[i][idx[1]] for i, idx in enumerate(indice)]
                    y_batch = [(i1 - i2) % 2 == 0 for i1, i2 in indice]
                elif opt.task == 'C':
                    batch_indice = np.concatenate(
                        [np.random.permutation(opt.batch_size // 2),
                         range(opt.batch_size // 2, opt.batch_size)])
                    y_batch = (np.arange(opt.batch_size) == batch_indice)
                    rn = np.random.choice(7, size=opt.batch_size)
                    x_1 = [sents[i][idx[0]] for i, idx in enumerate(indice)]
                    x_2 = [sents[batch_indice[i]][idx[1]]
                           for i, idx in enumerate(indice)]
                else:  # G
                    batch_indice = np.concatenate(
                        [np.random.permutation(opt.batch_size // 2),
                         range(opt.batch_size // 2, opt.batch_size)])
                    y_batch = (np.arange(opt.batch_size) == batch_indice)
                    x_1 = [sents[i][idx[0]] for i, idx in enumerate(indice)]
                    x_2 = [sents[batch_indice[i]][idx[1]]
                           for i, idx in enumerate(indice)]

                x_1_batch = prepare_data_for_cnn(x_1, opt)  # Batch x L
                x_2_batch = prepare_data_for_cnn(x_2, opt)  # Batch x L
                feed = {x_1_: x_1_batch, x_2_: x_2_batch,
                        y_: np.float32(y_batch), l_temp_: opt.l_temp}
                _, loss = sess.run([train_op, loss_], feed_dict=feed)

                if uidx % opt.print_freq == 1:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict=feed)
                    if opt.verbose:
                        print("logits:" + str(res['logits']))
                        print("H1:" + str(res['H_1'][0]))
                        print("H2:" + str(res['H_2'][0]))
                        # print("H2:" + str(res['H_1'][0] * res['H_2'][0] - 0.5))
                    acc = sum(np.equal(res['y_pred'], y_batch)) / \
                        np.float(opt.batch_size)
                    print("Accuracy: %f" % acc)
                    print("y_mean: %f" % np.mean(y_batch))
                    print("corr:" + str(res['corr']))
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.valid_freq == 1:
                    acc, loss_val, y_mean, corr = 0, 0, 0, 0
                    indice = [rand_pair(opt.task, opt.data_name)
                              for _ in range(opt.batch_size)]
                    for i in range(100):
                        valid_index = np.random.choice(len(test),
                                                       opt.batch_size)
                        sents = [test[t] for t in valid_index]
                        if opt.task == 'L':
                            x_1 = [sents[i][idx[0]]
                                   for i, idx in enumerate(indice)]
                            x_2 = [sents[i][idx[1]]
                                   for i, idx in enumerate(indice)]
                            y_batch = [(i1 - i2) % 2 == 0
                                       for i1, i2 in indice]
                        elif opt.task == 'C':
                            batch_indice = np.concatenate(
                                [np.random.permutation(opt.batch_size // 2),
                                 range(opt.batch_size // 2, opt.batch_size)])
                            y_batch = (np.arange(opt.batch_size) ==
                                       batch_indice)
                            rn = np.random.choice(7, size=opt.batch_size)
                            x_1 = [sents[i][idx[0]]
                                   for i, idx in enumerate(indice)]
                            x_2 = [sents[batch_indice[i]][idx[1]]
                                   for i, idx in enumerate(indice)]
                        else:  # G
                            batch_indice = np.concatenate(
                                [np.random.permutation(opt.batch_size // 2),
                                 range(opt.batch_size // 2, opt.batch_size)])
                            y_batch = (np.arange(opt.batch_size) ==
                                       batch_indice)
                            x_1 = [sents[i][idx[0]]
                                   for i, idx in enumerate(indice)]
                            x_2 = [sents[batch_indice[i]][idx[1]]
                                   for i, idx in enumerate(indice)]

                        x_1_batch = prepare_data_for_cnn(x_1, opt)
                        x_2_batch = prepare_data_for_cnn(x_2, opt)
                        feed = {x_1_: x_1_batch, x_2_: x_2_batch,
                                y_: np.float32(y_batch),
                                l_temp_: opt.l_temp}
                        loss_val += sess.run(loss_, feed_dict=feed)
                        res = sess.run(res_, feed_dict=feed)
                        acc += sum(np.equal(res['y_pred'], y_batch)) / \
                            np.float(opt.batch_size)
                        y_mean += np.mean(y_batch)
                        corr += res['corr']

                    loss_val = loss_val / 100.0
                    acc = acc / 100.0
                    y_mean = y_mean / 100.0
                    corr = corr / 100.0
                    print("Validation loss %.4f " % loss_val)
                    print("Validation accuracy: %.4f" % acc)
                    print("Validation y_mean: %.4f" % y_mean)
                    print("Validation corr: %.4f" % corr)
                    print("")
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    test_writer.add_summary(summary, uidx)

            saver.save(sess, opt.save_path, global_step=epoch)

        # Test.
        if opt.test:
            print('Testing....')
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            for i in range(iter_num):
                if i % 100 == 0:
                    print('Iter %i/%i' % (i, iter_num))
                test_index = range(i * opt.batch_size,
                                   (i + 1) * opt.batch_size)
                test_sents = [test[t % len(test)] for t in test_index]
                indice = [(0, 1), (2, 3), (4, 5), (6, 7)]
                for idx in indice:
                    x_1 = [test_sents[i][idx[0]]
                           for i in range(opt.batch_size)]
                    x_2 = [test_sents[i][idx[1]]
                           for i in range(opt.batch_size)]
                    y_batch = [True for i in range(opt.batch_size)]
                    x_1_batch = prepare_data_for_cnn(x_1, opt)  # Batch x L
                    x_2_batch = prepare_data_for_cnn(x_2, opt)  # Batch x L
                    feed = {x_1_: x_1_batch, x_2_: x_2_batch,
                            y_: np.float32(y_batch), l_temp_: opt.l_temp}
                    res = sess.run(res_, feed_dict=feed)
                    for d in range(opt.batch_size):
                        with open(opt.log_path + '.feature.txt',
                                  "a") as myfile:
                            myfile.write(
                                str(test_index[d]) + "\t" + str(idx[0]) +
                                "\t" + " ".join(
                                    [ixtoword[x] for x in x_1_batch[d]
                                     if x != 0]) + "\t" +
                                " ".join(map(str, res['H_1'][d])) + "\n")
                            myfile.write(
                                str(test_index[d]) + "\t" + str(idx[1]) +
                                "\t" + " ".join(
                                    [ixtoword[x] for x in x_2_batch[d]
                                     if x != 0]) + "\t" +
                                " ".join(map(str, res['H_2'][d])) + "\n")
def run_model(opt, train_unlab_x, train_lab_x, train_lab, val_unlab_x,
              val_lab_x, val_lab, test, test_y, wordtoix, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        alpha_ = tf.placeholder(tf.float32, shape=())
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_lab_ = tf.placeholder(tf.int32,
                                shape=[opt.dis_batch_size, opt.sent_len])
        y_ = tf.placeholder(tf.float32, shape=[opt.dis_batch_size, 1])
        dp_ratio_ = tf.placeholder(tf.float32, name='dp_ratio_')
        res_, dis_loss_, rec_loss_, loss_, train_op, prob_, acc_ = \
            semi_classifier(alpha_, x_, x_org_, x_lab_, y_, dp_ratio_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    max_val_accuracy = 0.0
    max_test_accuracy = 0.0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # config = tf.ConfigProto(device_count={'GPU': 0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train_unlab_x), opt.batch_size,
                                     shuffle=True)
            for _, train_index in kf:
                uidx += 1
                if opt.rec_alpha > 0 and uidx > opt.pretrain_step and \
                        uidx % opt.rec_decay_freq == 0:
                    opt.rec_alpha -= 0.01
                    print "alpha: " + str(opt.rec_alpha)

                sents = [train_unlab_x[t] for t in train_index]
                lab_index = np.random.choice(len(train_lab),
                                             opt.dis_batch_size,
                                             replace=False)
                lab_sents = [train_lab_x[t] for t in lab_index]
                batch_lab = [train_lab[t] for t in lab_index]
                batch_lab = np.array(batch_lab)
                batch_lab = batch_lab.reshape((len(batch_lab), 1))
                x_batch_lab = prepare_data_for_cnn(lab_sents, opt)
                sents_permutated = add_noise(sents, opt)

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch x L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch x L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt,
                                                   is_add_GO=False)

                _, dis_loss, rec_loss, loss, acc = sess.run(
                    [train_op, dis_loss_, rec_loss_, loss_, acc_],
                    feed_dict={alpha_: opt.rec_alpha, x_: x_batch,
                               x_org_: x_batch_org, x_lab_: x_batch_lab,
                               y_: batch_lab,
                               dp_ratio_: opt.dropout_ratio})
                summary = sess.run(
                    merged,
                    feed_dict={alpha_: opt.rec_alpha, x_: x_batch,
                               x_org_: x_batch_org, x_lab_: x_batch_lab,
                               y_: batch_lab,
                               dp_ratio_: opt.dropout_ratio})
                train_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: dis_loss %f, rec_loss %f, "
                          "loss %f, acc %f " %
                          (uidx, dis_loss, rec_loss, loss, acc))

                if uidx % opt.valid_freq == 0:
                    # print("Iteration %d: dis_loss %f, rec_loss %f, loss %f " % (uidx, dis_loss, rec_loss, loss))
                    valid_index = np.random.choice(len(val_unlab_x),
                                                   opt.batch_size)
                    val_sents = [val_unlab_x[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    rec_loss_val = sess.run(
                        rec_loss_,
                        feed_dict={x_: x_val_batch,
                                   x_org_: x_val_batch_org, dp_ratio_: 1.0})
                    print("Validation rec loss %f " % rec_loss_val)

                    kf_val = get_minibatches_idx(len(val_lab_x),
                                                 opt.dis_batch_size,
                                                 shuffle=False)
                    prob_val = []
                    for _, val_ind in kf_val:
                        val_sents = [val_lab_x[t] for t in val_ind]
                        x_val_dis = prepare_data_for_cnn(val_sents, opt)
                        val_y = np.array([val_lab[t]
                                          for t in val_ind]).reshape(
                            (opt.dis_batch_size, 1))
                        val_prob = sess.run(prob_,
                                            feed_dict={x_lab_: x_val_dis,
                                                       dp_ratio_: 1.0})
                        for x in val_prob:
                            prob_val.append(x)

                    # NOTE (original): the val_index-based grouping below was
                    # unclear and raised errors; kept for reference.
                    # probs = []
                    # val_truth = []
                    # for i in range(len(val_lab)):
                    #     val_truth.append(val_lab[i])
                    #     if type(val_index[i]) != int:
                    #         temp = []
                    #         for j in val_index[i]:
                    #             temp.append(prob_val[j])
                    #         aver = sum(temp) * 1.0 / len(temp)
                    #         probs.append(aver)
                    #     else:
                    #         probs.append(prob_val[val_index[i]])
                    probs = []
                    val_truth = []
                    for i in range(len(prob_val)):
                        val_truth.append(val_lab[i])
                        probs.append(prob_val[i])

                    # Thresholded accuracy: predict 1 iff p > 0.5.
                    count = 0.0
                    for i in range(len(probs)):
                        p = probs[i]
                        if p > 0.5:
                            if val_truth[i] == 1:
                                count += 1.0
                        else:
                            if val_truth[i] == 0:
                                count += 1.0
                    val_accuracy = count * 1.0 / len(probs)
                    print("Validation accuracy %f " % val_accuracy)

                    summary = sess.run(
                        merged,
                        feed_dict={alpha_: opt.rec_alpha, x_: x_val_batch,
                                   x_org_: x_val_batch_org,
                                   x_lab_: x_val_dis, y_: val_y,
                                   dp_ratio_: 1.0})
                    test_writer.add_summary(summary, uidx)

                    if val_accuracy >= max_val_accuracy:
                        max_val_accuracy = val_accuracy
                        kf_test = get_minibatches_idx(len(test),
                                                      opt.dis_batch_size,
                                                      shuffle=False)
                        prob_test = []
                        for _, test_ind in kf_test:
                            test_sents = [test[t] for t in test_ind]
                            x_test_batch = prepare_data_for_cnn(test_sents,
                                                                opt)
                            test_prob = sess.run(
                                prob_, feed_dict={x_lab_: x_test_batch,
                                                  dp_ratio_: 1.0})
                            for x in test_prob:
                                prob_test.append(x)

                        probs = []
                        test_truth = []
                        for i in range(len(prob_test)):
                            test_truth.append(test_y[i])
                            probs.append(prob_test[i])
                        # probs = []
                        # test_truth = []
                        # for i in range(len(test_y)):
                        #     test_truth.append(test_y[i])
                        #     if type(test_index[i]) != int:
                        #         temp = [prob_test[j] for j in test_index[i]]
                        #         aver = sum(temp) * 1.0 / len(temp)
                        #         probs.append(aver)
                        #     else:
                        #         probs.append(prob_test[test_index[i]])

                        count = 0.0
                        for i in range(len(probs)):
                            p = probs[i]
                            if p > 0.5:
                                if test_truth[i] == 1.0:
                                    count += 1.0
                            else:
                                if test_truth[i] == 0.0:
                                    count += 1.0
                        test_accuracy = count * 1.0 / len(probs)
                        print("Test accuracy %f " % test_accuracy)
                        max_test_accuracy = test_accuracy

            def test_input(text):
                x_input = sent2idx(text, wordtoix, opt)
                res = sess.run(res_, feed_dict={x_: x_input,
                                                x_org_: x_batch_org})
                print "Reconstructed:" + " ".join(
                    [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

            # res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
            # print "Original :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
            # print "Permutated :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
            # if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
            #     print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
            # print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])
            # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] + '(' + str(np.round(res['all_p'][i], 2)) + ')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

            print(opt.rec_alpha)
            print("Epoch %d: Max Valid accuracy %f" %
                  (epoch, max_val_accuracy))
            print("Epoch %d: Max Test accuracy %f" %
                  (epoch, max_test_accuracy))
            saver.save(sess, opt.save_path, global_step=epoch)
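# NOTE: sent2idx, used by the test_input helper above, is imported from the
# project's utilities. Judging from its use, it turns a raw string into a
# (batch_size, sent_len) index array that the x_ placeholder accepts; a rough
# sketch (the tokenization and padding conventions are assumptions):
def sent2idx_sketch(text, wordtoix, opt):
    ix = [wordtoix[w] for w in text.split() if w in wordtoix]
    ix = ix[:opt.sent_len]
    ix = ix + [0] * (opt.sent_len - len(ix))        # pad with PAD id 0
    return np.tile(np.asarray(ix, dtype='int32'),   # repeat to fill the batch
                   (opt.batch_size, 1))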
def run_model(opt, train, val, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    # config = tf.ConfigProto(device_count={'GPU': 0})
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    # Keep all checkpoints.
    saver = tf.train.Saver(max_to_keep=None)
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                # pdb.set_trace()
                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        # Validation.
        valid_index = np.random.choice(len(val), opt.batch_size)
        val_sents = [val[t] for t in valid_index]
        val_sents_permutated = add_noise(val_sents, opt)
        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
        d_loss_val = sess.run(d_loss_, feed_dict={x_: x_val_batch,
                                                  x_org_: x_val_batch_org})
        g_loss_val = sess.run(g_loss_, feed_dict={x_: x_val_batch,
                                                  x_org_: x_val_batch_org})
        res = sess.run(res_, feed_dict={x_: x_val_batch,
                                        x_org_: x_val_batch_org})
        try:
            print("Validation d_loss %f, g_loss %f mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + u' '.join(
                [ixtoword[x] for x in res['syn_sent'][0] if x != 0]))
            # .encode('utf-8', 'ignore').decode("utf8").strip()
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
        except Exception as e:
            print(e)
        # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
        if opt.discrimination:
            print("Real Prob %f Fake Prob %f" %
                  (res['prob_r'], res['prob_f']))

        # Generate 268,590 sentences in batches.
        for i in range(268590 // 1000 + 1):
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            val_sents_permutated = add_noise(val_sents, opt)
            x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
            x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
            res = sess.run(res_, feed_dict={x_: x_val_batch,
                                            x_org_: x_val_batch_org})
            if i == 0:
                valid_text = res['syn_sent']
            else:
                valid_text = np.concatenate((valid_text, res['syn_sent']), 0)
        valid_text = valid_text[:268590]
        np.savetxt(PATH_TO_SAVE, valid_text, fmt='%i', delimiter=' ')
        print('saved!\n\n\n')
        exit()

        # NOTE: unreachable after the exit() above.
        val_set = [prepare_for_bleu(s) for s in val]  # val_sents
        bleu_prepared = [prepare_for_bleu(s) for s in res['syn_sent']]
        for i in range(len(val_set) // opt.batch_size):
            batch = val_set[i * opt.batch_size:(i + 1) * opt.batch_size]
            [bleu2s, bleu3s, bleu4s] = cal_BLEU(bleu_prepared, {0: batch})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))
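# NOTE: cal_BLEU / prepare_for_bleu are project utilities. Functionally, the
# call above scores each generated sentence against a reference set at n-gram
# orders 2-4; a rough NLTK-based equivalent (an assumption, not the project's
# implementation):
def cal_bleu_sketch(generated, references):
    # generated: list of token lists; references: list of reference token
    # lists shared by every hypothesis.
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    sm = SmoothingFunction().method1
    weights = {2: (0.5, 0.5), 3: (1.0 / 3,) * 3, 4: (0.25,) * 4}
    return [np.mean([sentence_bleu(references, g, weights=weights[n],
                                   smoothing_function=sm)
                     for g in generated])
            for n in (2, 3, 4)]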
def run_model(opt, train, val, test, wordtoix, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    # writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())
    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    # config = tf.ConfigProto(device_count={'GPU': 0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                # pdb.set_trace()
                t_vars = tf.trainable_variables()
                # print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saved session found, using random initialization.")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                # sents[0] = np.random.permutation(sents[0])

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch x L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch x L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt,
                                                   is_add_GO=False)

                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org})
                # pdb.set_trace()
                if profile:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={x_: x_batch, x_org_: x_batch_org,
                                   is_train_: 1},
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch,
                                                  x_org_: x_batch_org,
                                                  is_train_: 1})
                # pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={x_: x_val_batch,
                                                   x_org_: x_val_batch_org,
                                                   is_train_: is_train})
                    print("Validation loss %f " % loss_val)
                    res = sess.run(res_,
                                   feed_dict={x_: x_val_batch,
                                              x_org_: x_val_batch_org,
                                              is_train_: is_train})
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    print "Val Orig :" + " ".join(
                        [ixtoword[x] for x in val_sents[0] if x != 0])
                    # print "Val Perm :" + " ".join([ixtoword[x] for x in val_sents_permutated[0] if x != 0])
                    print "Val Recon:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3))
                         for it in (bleu2s, bleu3s, bleu4s)])
                    # if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #     print "Org Probs:" + " ".join(
                    #         [ixtoword[x_val_batch_org[0][i]] + '(' + str(np.round(res['all_p'][i], 1)) + ')'
                    #          for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    #     print "Gen Probs:" + " ".join(
                    #         [ixtoword[res['rec_sents'][0][i]] + '(' + str(np.round(res['gen_p'][i], 1)) + ')'
                    #          for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged,
                                       feed_dict={x_: x_val_batch,
                                                  x_org_: x_val_batch_org,
                                                  is_train_: is_train})
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                def test_input(text):
                    x_input = sent2idx(text, wordtoix, opt)
                    res = sess.run(res_, feed_dict={x_: x_input,
                                                    x_org_: x_batch_org})
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                if uidx % opt.print_freq == 0:
                    # pdb.set_trace()
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict={x_: x_batch,
                                                    x_org_: x_batch_org,
                                                    is_train_: 1})
                    print "Original :" + " ".join(
                        [ixtoword[x] for x in sents[0] if x != 0])
                    # print "Permutated :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        print "Reconstructed:" + " ".join(
                            [ixtoword[x] for x in res['rec_sents_feed_y'][0]
                             if x != 0])
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])
                    # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] + '(' + str(np.round(res['all_p'][i], 2)) + ')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])

                    summary = sess.run(merged,
                                       feed_dict={x_: x_batch,
                                                  x_org_: x_batch_org,
                                                  is_train_: 1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            if profile:
                tf.contrib.tfprof.model_analyzer.print_model_analysis(
                    tf.get_default_graph(), run_meta=run_metadata,
                    tfprof_options=tf.contrib.tfprof.model_analyzer.
                    PRINT_ALL_TIMING_MEMORY)
            saver.save(sess, opt.save_path, global_step=epoch)
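# NOTE: restore_from_save is shared across these scripts but never defined in
# this section. The inline restore block in the cons_disc main(opt) above
# suggests it restores only the variables whose name and shape both match the
# checkpoint, roughly as in this sketch (the signature details are
# assumptions):
def restore_from_save_sketch(t_vars, sess, opt, load_path=None):
    load_path = load_path or opt.load_path
    save_keys = tensors_key_in_file(load_path)  # checkpoint name -> shape
    keep = [var for var in t_vars
            if var.name[:-2] in save_keys                     # strip ':0'
            and var.get_shape() == save_keys[var.name[:-2]]]
    loader = tf.train.Saver(var_list=keep)
    loader.restore(sess, load_path)
    return loader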
def main():
    #global n_words
    # Prepare training and testing data.
    opt = COptions(args)
    opt_t = COptions(args)
    loadpath = opt.data_dir + "/" + opt.data_name
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    if opt.test:
        test_file = opt.data_dir + opt.test_file
        test = read_test(test_file, wordtoix)

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
                    for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            z_ = tf.placeholder(tf.float32,
                                shape=[opt_t.batch_size, opt.n_z * (2 if opt.local_feature else 1)])
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
            res_2_ = generate_resp(src_, tgt_, z_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()
    # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True,
                            graph_options=graph_options)
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt, load_path=opt.save_path)
            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            # Pass 1: encode each context, then generate responses from the
            # inferred latent z.
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            res_all = []
            val_tgt_all = []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                sents = [test[t % len(test)] for t in test_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[sents[i][idx - turn] for i in range(opt.batch_size)]
                           for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    val_tgt_all.extend(tgt)
                    if opt.feed_generated and idx != opt.n_context:
                        src[-1] = [[x for x in p if x != 0] for p in res_all[-opt.batch_size:]]
                    x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]  # Batch x L
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                    feed = merge_two_dicts({i: d for i, d in zip(src_, x_batch)},
                                           {tgt_: y_batch, is_train_: 0})  # do not use False
                    res_1 = sess.run(res_1_, feed_dict=feed)
                    z_all = np.array(res_1['z'])
                    feed = merge_two_dicts({i: d for i, d in zip(src_, x_batch)},
                                           {tgt_: y_batch, z_: z_all, is_train_: 0})  # do not use False
                    res_2 = sess.run(res_2_, feed_dict=feed)
                    res_all.extend(res_2['syn_sent'])

            val_tgt_all = reshaping(val_tgt_all, opt)
            res_all = reshaping(res_all, opt)

            save_path = opt.log_path + '.resp.txt'
            if os.path.exists(save_path):
                os.remove(save_path)
            for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                with open(save_path, "a") as resp_f:
                    resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() +
                                 ('\n' if idx % (opt.num_turn - opt.n_context) == opt.num_turn - opt.n_context - 1 else '\t'))
            print("save to:" + save_path)

            if opt.verbose:
                save_path = opt.log_path + '.tgt.txt'
                if os.path.exists(save_path):
                    os.remove(save_path)
                for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                    with open(save_path, "a") as tgt_f:
                        tgt_f.write(u' '.join([ixtoword[x] for x in val_tgt_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() +
                                    ('\n' if idx % (opt.num_turn - opt.n_context) == opt.num_turn - opt.n_context - 1 else '\t'))
                print("save to:" + save_path)

            val_set = [prepare_for_bleu(s) for s in val_tgt_all]
            gen = [prepare_for_bleu(s) for s in res_all]
            [bleu1s, bleu2s, bleu3s, bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus=opt.is_corpus)
            etp_score, dist_score = cal_entropy(gen)
            print 'Val BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu1s, bleu2s, bleu3s, bleu4s)])
            print 'Val Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])])
            print 'Val Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])])
            print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3))

            if opt.embedding_score:
                with open("../../ssd0/consistent_dialog/data/GoogleNews-vectors-negative300.bin.p", 'rb') as pfile:
                    embedding = cPickle.load(pfile)
                rel_score = cal_relevance(gen, val_set, embedding)
                print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])])

        if not opt.global_feature or opt.bit is None:
            exit(0)

        if opt.test:
            # Pass 2: sweep one latent dimension (opt.bit) over [0, 1] in
            # opt.int_num steps and regenerate, to probe what it controls.
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            for int_idx in range(opt.int_num):
                res_all = []
                z1, z2, z3 = [], [], []
                val_tgt_all = []
                for i in range(iter_num):
                    test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                    sents = [test[t % len(test)] for t in test_index]
                    for idx in range(opt.n_context, opt.num_turn):
                        src = [[sents[i][idx - turn] for i in range(opt.batch_size)]
                               for turn in range(opt.n_context, 0, -1)]
                        tgt = [sents[i][idx] for i in range(opt.batch_size)]
                        val_tgt_all.extend(tgt)
                        if opt.feed_generated and idx != opt.n_context:
                            src[-1] = [[x for x in p if x != 0] for p in res_all[-opt.batch_size:]]
                        x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]  # Batch x L
                        y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                        feed = merge_two_dicts({i: d for i, d in zip(src_, x_batch)},
                                               {tgt_: y_batch, is_train_: 0})  # do not use False
                        res_1 = sess.run(res_1_, feed_dict=feed)
                        z_all = np.array(res_1['z'])
                        z_all[:, opt.bit] = np.array([1.0 / np.float(opt.int_num - 1) * int_idx
                                                      for _ in range(opt.batch_size)])
                        feed = merge_two_dicts({i: d for i, d in zip(src_, x_batch)},
                                               {tgt_: y_batch, z_: z_all, is_train_: 0})  # do not use False
                        res_2 = sess.run(res_2_, feed_dict=feed)
                        res_all.extend(res_2['syn_sent'])
                        z1.extend(res_1['z'])
                        z2.extend(z_all)
                        z3.extend(res_2['z_hat'])

                val_tgt_all = reshaping(val_tgt_all, opt)
                res_all = reshaping(res_all, opt)
                z1 = reshaping(z1, opt)
                z2 = reshaping(z2, opt)
                z3 = reshaping(z3, opt)

                save_path = opt.log_path + 'bit' + str(opt.bit) + '.' + str(1.0 / np.float(opt.int_num - 1) * int_idx) + '.int.txt'
                if os.path.exists(save_path):
                    os.remove(save_path)
                for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                    with open(save_path, "a") as resp_f:
                        resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() +
                                     ('\n' if idx % (opt.num_turn - opt.n_context) == opt.num_turn - opt.n_context - 1 else '\t'))
                print("save to:" + save_path)

                save_path_z = opt.log_path + 'bit' + str(opt.bit) + '.' + str(1.0 / np.float(opt.int_num - 1) * int_idx) + '.z.txt'
                if os.path.exists(save_path_z):
                    os.remove(save_path_z)
                for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                    with open(save_path_z, "a") as myfile:
                        myfile.write(str(z3[idx][opt.bit]) +
                                     ('\n' if idx % (opt.num_turn - opt.n_context) == opt.num_turn - opt.n_context - 1 else '\t'))

                val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                gen = [prepare_for_bleu(s) for s in res_all]
                [bleu1s, bleu2s, bleu3s, bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus=opt.is_corpus)
                etp_score, dist_score = cal_entropy(gen)
                print save_path
                print 'Val BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu1s, bleu2s, bleu3s, bleu4s)])
                print 'Val Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])])
                print 'Val Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])])
                print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3))
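# `merge_two_dicts` is used above to assemble feed dicts from the list of
# context placeholders plus the target/latent entries. A minimal sketch of the
# standard Python 2 helper (the body is an assumption; only the name and call
# signature come from the call sites):
def merge_two_dicts(x, y):
    z = x.copy()   # start from x
    z.update(y)    # y's entries win on key collisions
    return z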
def main():
    opt = Options(args)
    opt_t = Options(args)
    opt_t.n_hid = opt.n_z
    loadpath = opt.data_dir + "/" + opt.data_name
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    if opt.model == 'cnn_rnn':
        opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
        opt_t.update_params(args)
        print dict(opt_t)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            z_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.n_z])
            res_, gan_cost_g_, train_op_g = conditional_s2s(src_, tgt_, z_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                if opt.load_from_pretrain:
                    d_vars = [var for var in t_vars if var.name.startswith('d_')]
                    g_vars = [var for var in t_vars if var.name.startswith('g_')]
                    restore_from_save(g_vars, sess, opt, opt.restore_dir + "/save/generator")
                    restore_from_save(d_vars, sess, opt, opt.restore_dir + "/save/discriminator")
                else:
                    loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                indice = [(x, x + 1) for x in range(7)]  # adjacent-turn (src, tgt) pairs
                sents = [train[t] for t in train_index]
                for idx in indice:
                    src = [sents[i][idx[0]] for i in range(opt.batch_size)]
                    tgt = [sents[i][idx[1]] for i in range(opt.batch_size)]
                    src_permutated = src
                    x_batch = prepare_data_for_cnn(src_permutated, opt)
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                    if opt.z_prior == 'g':
                        # Note: the original sampled shape (opt.n_hid, opt.n_z) here,
                        # which would not match the z_ placeholder; batch_size is used instead.
                        z_batch = np.random.normal(0, 1, (opt.batch_size, opt.n_z)).astype('float32')
                    else:
                        z_batch = np.random.uniform(-1, 1, (opt.batch_size, opt.n_z)).astype('float32')
                    feed = {src_: x_batch, tgt_: y_batch, z_: z_batch}
                    _, loss_g = sess.run([train_op_g, gan_cost_g_], feed_dict=feed)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss G %f" % (uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    print "Source:" + u' '.join([ixtoword[x] for x in x_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([ixtoword[x] for x in y_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(test), opt.batch_size)
                        indice = [(x, x + 1) for x in range(7)]
                        val_sents = [test[t] for t in valid_index]
                        for idx in indice:
                            val_src = [val_sents[i][idx[0]] for i in range(opt.batch_size)]
                            val_tgt = [val_sents[i][idx[1]] for i in range(opt.batch_size)]
                            val_tgt_all.extend(val_tgt)
                            val_src_permutated = val_src
                            x_val_batch = prepare_data_for_cnn(val_src, opt)
                            # Note: the original prepared the "target" batch from val_src;
                            # val_tgt is used here, matching the training loop above.
                            y_val_batch = prepare_data_for_rnn(val_tgt, opt_t, is_add_GO=False) \
                                if opt.model == 'cnn_rnn' else prepare_data_for_cnn(val_tgt, opt_t)
                            if opt.z_prior == 'g':
                                z_val_batch = np.random.normal(0, 1, (opt.batch_size, opt.n_z)).astype('float32')
                            else:
                                z_val_batch = np.random.uniform(-1, 1, (opt.batch_size, opt.n_z)).astype('float32')
                            feed_val = {src_: x_val_batch, tgt_: y_val_batch, z_: z_val_batch}
                            loss_val_g = sess.run([gan_cost_g_], feed_dict=feed_val)
                            loss_val_g_all.append(loss_val_g)
                            res = sess.run(res_, feed_dict=feed_val)
                            res_all.extend(res['syn_sent'])
                            if opt.global_feature:
                                z_loss_all.append(res['z_loss'])

                    print("Validation: loss G %f " % np.mean(loss_val_g_all))
                    print "Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()
                    print "Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    if opt.global_feature:
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" % (np.mean(z_loss_all)) + "\n")
                            myfile.write("Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write(np.array2string(res['z'][0], formatter={'float_kind': lambda x: "%.2f" % x}) + "\n")
                            myfile.write(np.array2string(res['z_hat'][0], formatter={'float_kind': lambda x: "%.2f" % x}) + "\n\n")

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s, bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    print 'Val BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu1s, bleu2s, bleu3s, bleu4s)])
                    print 'Val Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])])
                    print 'Val Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])])
                    print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3))
                    print ""

                    summary = sess.run(merged, feed_dict=feed_val)
                    summary2 = tf.Summary(value=[tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                                                 tf.Summary.Value(tag="etp-4", simple_value=etp_score[3])])
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
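# `prepare_for_bleu` is applied to both references and generations before the
# BLEU calls. A sketch under the assumption that it truncates at the EOS id
# (2, judging by the `x != 2` filters elsewhere in this file) and strips
# padding (0); the real helper may also cap length or map ids to strings.
def prepare_for_bleu_sketch(sentence, eos=2, pad=0):
    out = []
    for w in sentence:
        if w == eos:
            break          # keep tokens up to, not including, EOS
        if w != pad:
            out.append(w)  # drop padding tokens
    return out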
def run_epoch(sess, epoch, mode, print_freq=-1, display_sent=-1, train_writer=None):
    fetches_ = {'loss': loss_}
    if mode == 'train':
        x, is_train = train, 1
        fetches_['train_op'] = train_op_
        fetches_['summary'] = merged
    elif mode == 'val':
        assert (print_freq == -1)
        x, is_train = val, None
    elif mode == 'test':
        assert (print_freq == -1)
        x, is_train = test, None

    acc_loss, acc_n = 0.0, 0.0
    local_t = 0
    global_t = epoch * epoch_t  # only used in train mode
    start_time = time.time()
    kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)
    for _, index in kf:
        local_t += 1
        global_t += 1
        sents_b = [x[i] for i in index]
        sents_b_n = add_noise(sents_b, opt)
        x_b_org = prepare_data_for_rnn(sents_b, opt)  # Batch x L
        x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch x L
        feed_t = {x_: x_b, x_org_: x_b_org, is_train_: is_train}
        fetches = sess.run(fetches_, feed_dict=feed_t)
        batch_size = len(index)
        acc_n += batch_size
        acc_loss += fetches['loss'] * batch_size
        if print_freq > 0 and local_t % print_freq == 0:
            print("%s Iter %d: loss %.4f, time %.1fs" %
                  (mode, local_t, acc_loss / acc_n, time.time() - start_time))
            sys.stdout.flush()
        if mode == 'train' and train_writer is not None:
            train_writer.add_summary(fetches['summary'], global_t)

    if display_sent > 0:
        # Show a few sampled reconstructions at the end of the epoch.
        index_d = np.random.choice(len(x), opt.batch_size, replace=False)
        sents_d = [x[i] for i in index_d]
        sents_d_n = add_noise(sents_d, opt)
        x_d_org = prepare_data_for_rnn(sents_d, opt)  # Batch x L
        x_d = prepare_data_for_cnn(sents_d_n, opt)  # Batch x L
        res = sess.run(res_, feed_dict={x_: x_d, x_org_: x_d_org, is_train_: is_train})
        for i in range(display_sent):
            print("%s Org: " % mode +
                  " ".join([ixtoword[ix] for ix in sents_d[i] if ix != 0 and ix != 2]))
            if mode == 'train':
                print("%s Rec(feedy): " % mode +
                      " ".join([ixtoword[ix] for ix in res['rec_sents_feed_y'][i] if ix != 0 and ix != 2]))
            print("%s Rec: " % mode +
                  " ".join([ixtoword[ix] for ix in res['rec_sents'][i] if ix != 0 and ix != 2]))

    print("%s Epoch %d: loss %.4f, time %.1fs" %
          (mode, epoch, acc_loss / acc_n, time.time() - start_time))
    return acc_loss / acc_n
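# `get_minibatches_idx` is consumed as `for _, index in kf` and via `len(kf)`,
# which matches the well-known helper from the Theano LSTM tutorial: a list of
# (batch_number, index_array) pairs. A sketch consistent with those call sites
# (the repo's actual body is an assumption):
def get_minibatches_idx_sketch(n, minibatch_size, shuffle=False):
    idx_list = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    start = 0
    for _ in range(n // minibatch_size):
        minibatches.append(idx_list[start:start + minibatch_size])
        start += minibatch_size
    if start != n:
        minibatches.append(idx_list[start:])  # final, smaller batch
    return list(zip(range(len(minibatches)), minibatches))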
def run_epoch(sess, epoch, mode, print_freq=-1, display_sent=-1, train_writer=None):
    fetches_ = {'loss': loss_, 'rec_loss': rec_loss_, 'kl_loss': kl_loss_}
    if mode == 'train':
        x, is_train = train, 1
        fetches_['train_op'] = train_op_
        fetches_['summary'] = merged
    elif mode == 'val':
        assert (print_freq == -1)
        x, is_train = val, None
    elif mode == 'test':
        assert (print_freq == -1)
        x, is_train = test, None

    acc_loss, acc_rec, acc_kl, acc_n = 0.0, 0.0, 0.0, 0.0
    local_t = 0
    global_t = epoch * epoch_t  # only used in train mode
    start_time = time.time()
    kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)
    for _, index in kf:
        local_t += 1
        global_t_cyc = global_t % cycle_t
        # Cosine learning-rate schedule, restarted every cycle_t steps.
        lr_t = 0.5 * opt.lr * (1 + np.cos(float(global_t_cyc) / cycle_t * np.pi))
        global_t += 1
        # KL weight: linear warm-up to max_beta within each cycle (train only).
        if mode == 'train' and opt.vae_anneal:
            beta_t = opt.max_beta * np.minimum((global_t_cyc + 1.) / full_kl_step, 1.)
        else:
            beta_t = opt.max_beta
        sents_b = [x[i] for i in index]
        sents_b_n = add_noise(sents_b, opt)
        x_b_org = prepare_data_for_rnn(sents_b, opt)  # Batch x L
        x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch x L
        feed_t = {beta_: beta_t, x_: x_b, x_org_: x_b_org, is_train_: is_train, lr_: lr_t}
        fetches = sess.run(fetches_, feed_dict=feed_t)
        batch_size = len(index)
        acc_n += batch_size
        acc_loss += fetches['loss'] * batch_size
        acc_rec += fetches['rec_loss'] * batch_size
        acc_kl += fetches['kl_loss'] * batch_size
        if print_freq > 0 and local_t % print_freq == 0:
            print("%s Iter %d: loss %.4f, rec %.4f, kl %.4f, beta %.4f, lr %.4fe-4, time %.1fs" %
                  (mode, local_t, acc_loss / acc_n, acc_rec / acc_n, acc_kl / acc_n,
                   beta_t, lr_t * 1e4, time.time() - start_time))
            sys.stdout.flush()
        if mode == 'train' and train_writer is not None:
            train_writer.add_summary(fetches['summary'], global_t)

    if display_sent > 0:
        # Show a few sampled reconstructions at the end of the epoch.
        index_d = np.random.choice(len(x), opt.batch_size, replace=False)
        sents_d = [x[i] for i in index_d]
        sents_d_n = add_noise(sents_d, opt)
        x_d_org = prepare_data_for_rnn(sents_d, opt)  # Batch x L
        x_d = prepare_data_for_cnn(sents_d_n, opt)  # Batch x L
        res = sess.run(res_, feed_dict={beta_: beta_t, x_: x_d, x_org_: x_d_org, is_train_: is_train})
        for i in range(display_sent):
            print("%s Org: " % mode +
                  " ".join([ixtoword[ix] for ix in sents_d[i] if ix != 0 and ix != 2]))
            if mode == 'train':
                print("%s Rec(feedy): " % mode +
                      " ".join([ixtoword[ix] for ix in res['rec_sents_feed_y'][i] if ix != 0 and ix != 2]))
            print("%s Rec: " % mode +
                  " ".join([ixtoword[ix] for ix in res['rec_sents'][i] if ix != 0 and ix != 2]))

    print("%s Epoch %d: loss %.4f, rec %.4f, kl %.4f, beta %.4f, time %.1fs" %
          (mode, epoch, acc_loss / acc_n, acc_rec / acc_n, acc_kl / acc_n,
           beta_t, time.time() - start_time))
    return acc_loss / acc_n, acc_rec / acc_n, acc_kl / acc_n
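# The inline learning-rate and KL-weight schedules above are easy to misread;
# the same math as standalone functions (a restatement for clarity, not new
# behavior):
def cyclic_cosine_lr(base_lr, global_t, cycle_t):
    # Cosine decay restarted every cycle_t steps: base_lr at the start of a
    # cycle, approaching 0 at its end.
    t = global_t % cycle_t
    return 0.5 * base_lr * (1 + np.cos(float(t) / cycle_t * np.pi))

def kl_beta(max_beta, global_t_cyc, full_kl_step, anneal=True):
    # Linear warm-up of the KL weight over full_kl_step steps, capped at
    # max_beta; without annealing the weight is constant.
    if not anneal:
        return max_beta
    return max_beta * min((global_t_cyc + 1.0) / full_kl_step, 1.0)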
def run_model(opt, train, val, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' + str(params['Wemb'].shape) +
                  ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch x L
                x_batch_org = prepare_data_for_rnn(sents, opt)
                d_loss = 0
                g_loss = 0
                # Alternate discriminator / generator updates on their own schedules.
                if profile:
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run([dis_op, d_loss_],
                                             feed_dict={x_: x_batch, x_org_: x_batch_org},
                                             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                             run_metadata=run_metadata)
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run([gen_op, g_loss_],
                                             feed_dict={x_: x_batch, x_org_: x_batch_org},
                                             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                             run_metadata=run_metadata)
                else:
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run([dis_op, d_loss_],
                                             feed_dict={x_: x_batch, x_org_: x_batch_org})
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run([gen_op, g_loss_],
                                             feed_dict={x_: x_batch, x_org_: x_batch_org})

                # validation
                if uidx % opt.valid_freq == 0:
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    x_val_batch_org = prepare_data_for_rnn(val_sents, opt)

                    d_loss_val = sess.run(d_loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    g_loss_val = sess.run(g_loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    print("Validation d_loss %f, g_loss %f mean_dist %f" %
                          (d_loss_val, g_loss_val, res['mean_dist']))
                    print("Sent:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]))
                    print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))

                    # Collect a few batches of synthetic sentences and dump them as word ids.
                    for i in range(4):
                        valid_index = np.random.choice(len(val), opt.batch_size)
                        val_sents = [val[t] for t in valid_index]
                        val_sents_permutated = add_noise(val_sents, opt)
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                        res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                        if i == 0:
                            valid_text = res['syn_sent']
                        else:
                            valid_text = np.concatenate((valid_text, res['syn_sent']), 0)
                    np.savetxt('./text_news/syn_val_words.txt', valid_text, fmt='%i', delimiter=' ')

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU([prepare_for_bleu(s) for s in res['syn_sent']], {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join([str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))

                    summary = sess.run(merged, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    test_writer.add_summary(summary, uidx)
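# `cal_BLEU` is repo-internal; its second argument `{0: val_set}` holds one
# shared reference pool, so every generated sentence is scored against the
# whole validation sample. For orientation only, a rough pooled equivalent
# with NLTK (the smoothing choice is an assumption, so scores will not match
# cal_BLEU exactly):
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def pooled_bleu(hyps, ref_pool, n=4):
    # Every hypothesis shares the same list of references.
    weights = tuple([1.0 / n] * n)
    return corpus_bleu([ref_pool] * len(hyps), hyps,
                       weights=weights,
                       smoothing_function=SmoothingFunction().method1)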
def main():
    # Prepare training and testing data.
    loadpath = "./data/"
    src_file = loadpath + "Pairs2M.src.num"
    tgt_file = loadpath + "Pairs2M.tgt.num"
    dic_file = loadpath + "Pairs2M.reddit.dic"
    opt = Options()
    opt_t = Options()

    train, val, test, wordtoix, ixtoword = read_pair_data_full(
        src_file, tgt_file, dic_file, max_num=opt.data_size, p_f=loadpath + 'demo.p')
    # Keep only pairs whose source and target lengths fit the model.
    train = [x for x in train
             if 2 < len(x[1]) < opt.maxlen - 4 and 2 < len(x[0]) < opt_t.maxlen - 4]
    val = [x for x in val
           if 2 < len(x[1]) < opt.maxlen - 4 and 2 < len(x[0]) < opt_t.maxlen - 4]
    if TEST_FLAG:
        test = [test]
        opt.test_freq = 1

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    print dict(opt)
    if opt.model == 'cnn_rnn':
        opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
        opt_t.update_params()
        print dict(opt_t)
    print('Total words: %d' % opt.n_words)

    # Load (or build and cache) the word2vec table used for relevance scores.
    if os.path.exists(opt.embedding_path_lime):
        with open(opt.embedding_path_lime, 'rb') as pfile:
            embedding = cPickle.load(pfile)
    else:
        w2v = gensim.models.KeyedVectors.load_word2vec_format(opt.embedding_path, binary=True)
        embedding = {i: copy.deepcopy(w2v[ixtoword[i]])
                     for i in range(opt.n_words) if ixtoword[i] in w2v}
        with open(opt.embedding_path_lime, 'wb') as pfile:
            cPickle.dump(embedding, pfile, protocol=cPickle.HIGHEST_PROTOCOL)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            res_, gan_cost_d_, train_op_d, gan_cost_g_, train_op_g = dialog_gan(src_, tgt_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)  # tf.trainable_variables()
                if opt.load_from_ae:
                    # Autoencoder checkpoints store names without the scope prefix
                    # (t_var g_W:0 -> key W), so match against stripped names.
                    save_keys = tensors_key_in_file(opt.load_path)
                    ss = [var for var in t_vars if var.name[2:][:-2] in save_keys.keys()]
                    ss = [var.name[2:] for var in ss
                          if var.get_shape() == save_keys[var.name[2:][:-2]]]
                    cc = {var.name[2:][:-2]: var for var in t_vars if var.name[2:] in ss}
                    loader = tf.train.Saver(var_list=cc)
                    loader.restore(sess, opt.load_path)
                    print("Loading variables from '%s'." % opt.load_path)
                    print("Loaded variables:" + " ".join(
                        [var.name for var in t_vars if var.name[2:] in ss]))
                else:
                    save_keys = tensors_key_in_file(opt.load_path)
                    ss = [var for var in t_vars if var.name[:-2] in save_keys.keys()]
                    ss = [var.name for var in ss if var.get_shape() == save_keys[var.name[:-2]]]
                    loader = tf.train.Saver(var_list=[var for var in t_vars if var.name in ss])
                    loader.restore(sess, opt.load_path)
                    print("Loading variables from '%s'." % opt.load_path)
                    print("Loaded variables:" + str(ss))
                    # Load the reverse model, if available.
                    try:
                        save_keys = tensors_key_in_file('./save/rev_model')
                        ss = [var for var in t_vars
                              if var.name[:-2] in save_keys.keys() and 'g_rev_' in var.name]
                        ss = [var.name for var in ss if var.get_shape() == save_keys[var.name[:-2]]]
                        loader = tf.train.Saver(var_list=[var for var in t_vars if var.name in ss])
                        loader.restore(sess, './save/rev_model')
                        print("Loading reverse variables from ./save/rev_model")
                        print("Loaded variables:" + str(ss))
                    except Exception as e:
                        print("No reverse model loaded")
            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1

                if uidx % opt.test_freq == 1:
                    iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
                    res_all, test_tgt_all = [], []
                    for i in range(iter_num):
                        test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                        test_tgt, test_src = zip(*[test[t % len(test)] for t in test_index])
                        test_tgt_all.extend(test_tgt)
                        x_batch = prepare_data_for_cnn(test_src, opt)
                        y_batch = prepare_data_for_rnn(test_tgt, opt_t, is_add_GO=False) \
                            if opt.model == 'cnn_rnn' else prepare_data_for_cnn(test_tgt, opt_t)
                        feed = {src_: x_batch, tgt_: y_batch}
                        res = sess.run(res_, feed_dict=feed)
                        res_all.extend(res['syn_sent'])

                    test_set = [prepare_for_bleu(s) for s in test_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s, bleu4s] = cal_BLEU_4(gen, {0: test_set}, is_corpus=opt.is_corpus)
                    [rouge1, rouge2, rouge3, rouge4, rougeL, rouges] = cal_ROUGE(gen, {0: test_set}, is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    bleu_nltk = cal_BLEU_4_nltk(gen, test_set, is_corpus=opt.is_corpus)
                    rel_score = cal_relevance(gen, test_set, embedding)
                    print 'Test BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)])
                    print 'Test Rouge: ' + ' '.join([str(round(it, 3)) for it in (rouge1, rouge2, rouge3, rouge4)])
                    print 'Test Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])])
                    print 'Test Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])])
                    print 'Test Relevance(G,A,E): ' + ' '.join([str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])])
                    print 'Test Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3))
                    print ''
                    if TEST_FLAG:
                        exit()

                tgt, src = zip(*[train[t] for t in train_index])
                x_batch = prepare_data_for_cnn(src, opt)  # Batch x L
                y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False) \
                    if opt.model == 'cnn_rnn' else prepare_data_for_cnn(tgt, opt_t)
                feed = {src_: x_batch, tgt_: y_batch}

                if uidx % opt.d_freq == 1:
                    if profile:
                        _, loss_d = sess.run([train_op_d, gan_cost_d_], feed_dict=feed,
                                             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                             run_metadata=run_metadata)
                    else:
                        _, loss_d = sess.run([train_op_d, gan_cost_d_], feed_dict=feed)

                if uidx % opt.g_freq == 1:
                    if profile:
                        _, loss_g = sess.run([train_op_g, gan_cost_g_], feed_dict=feed,
                                             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                             run_metadata=run_metadata)
                    else:
                        _, loss_g = sess.run([train_op_g, gan_cost_g_], feed_dict=feed)

                if profile:
                    tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(), run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
                    exit(0)

                if uidx % opt.valid_freq == 1:
                    VALID_SIZE = 1024
                    valid_multiplier = np.int(np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val), opt.batch_size)
                        val_tgt, val_src = zip(*[val[t] for t in valid_index])
                        val_tgt_all.extend(val_tgt)
                        x_val_batch = prepare_data_for_cnn(val_src, opt)  # Batch x L
                        y_val_batch = prepare_data_for_rnn(val_tgt, opt_t, is_add_GO=False) \
                            if opt.model == 'cnn_rnn' else prepare_data_for_cnn(val_tgt, opt_t)
                        feed_val = {src_: x_val_batch, tgt_: y_val_batch}
                        loss_val_d, loss_val_g = sess.run([gan_cost_d_, gan_cost_g_], feed_dict=feed_val)
                        loss_val_d_all.append(loss_val_d)
                        loss_val_g_all.append(loss_val_g)
                        res = sess.run(res_, feed_dict=feed_val)
                        res_all.extend(res['syn_sent'])

                    print("Validation: loss D %f loss G %f " %
                          (np.mean(loss_val_d_all), np.mean(loss_val_g_all)))
                    print "Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()
                    print "Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s, bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus=opt.is_corpus)
                    [rouge1, rouge2, rouge3, rouge4, rougeL, rouges] = cal_ROUGE(gen, {0: val_set}, is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    bleu_nltk = cal_BLEU_4_nltk(gen, val_set, is_corpus=opt.is_corpus)
                    rel_score = cal_relevance(gen, val_set, embedding)
                    print 'Val BLEU: ' + ' '.join([str(round(it, 3)) for it in (bleu_nltk, bleu1s, bleu2s, bleu3s, bleu4s)])
                    print 'Val Rouge: ' + ' '.join([str(round(it, 3)) for it in (rouge1, rouge2, rouge3, rouge4)])
                    print 'Val Entropy: ' + ' '.join([str(round(it, 3)) for it in (etp_score[0], etp_score[1], etp_score[2], etp_score[3])])
                    print 'Val Diversity: ' + ' '.join([str(round(it, 3)) for it in (dist_score[0], dist_score[1], dist_score[2], dist_score[3])])
                    print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it, 3)) for it in (rel_score[0], rel_score[1], rel_score[2])])
                    print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y != 0]) for x in res_all]), 3))
                    print ""

                    summary = sess.run(merged, feed_dict=feed_val)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="rouge-2", simple_value=rouge2),
                        tf.Summary.Value(tag="etp-4", simple_value=etp_score[3])])
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.print_freq == 1:
                    print("Iteration %d: loss D %f loss G %f" % (uidx, loss_d, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    if opt.grad_penalty:
                        print "grad_penalty: " + str(res['gp'])
                    print "Source:" + u' '.join([ixtoword[x] for x in x_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([ixtoword[x] for x in y_batch[0] if x != 0]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print ""
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.save_freq == 1:
                    saver.save(sess, opt.save_path)
def run_model(opt, train, val, ixtoword):
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' + str(params['Wemb'].shape) +
                  ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
                print('Load pretrain successfully')
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)

                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch x L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch x L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch x L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO=False)  # Batch x L

                if profile:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1},
                                       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                       run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})

                if uidx % opt.valid_freq == 0:
                    is_train = None  # evaluation mode
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)

                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_: is_train})
                    print("Validation loss %f " % loss_val)
                    res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_: is_train})
                    np.savetxt(opt.save_txt + '/rec_val_words.txt', res['rec_sents'], fmt='%i', delimiter=' ')
                    try:
                        print("Orig:" + u' '.join([ixtoword[x] for x in x_val_batch_org[0] if x != 0 and x != 1]))
                        print("Sent:" + u' '.join([ixtoword[x] for x in res['rec_sents'][0] if x != 0]))
                    except Exception:
                        pass
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))
                    summary = sess.run(merged, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org, is_train_: is_train})
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                if uidx % opt.print_freq == 1:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
                    np.savetxt(opt.save_txt + '/rec_train_words.txt', res['rec_sents'], fmt='%i', delimiter=' ')
                    try:
                        print("Orig:" + u' '.join([ixtoword[x] for x in x_batch_org[0] if x != 0 and x != 1]))
                        print("Sent:" + u' '.join([ixtoword[x] for x in res['rec_sents'][0] if x != 0]))
                    except Exception:
                        pass
                    summary = sess.run(merged, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
                    train_writer.add_summary(summary, uidx)

            if profile:
                tf.contrib.tfprof.model_analyzer.print_model_analysis(
                    tf.get_default_graph(), run_meta=run_metadata,
                    tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
            saver.save(sess, opt.save_path, global_step=epoch)
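# `restore_from_save` appears in every training script here. A hypothetical
# minimal stand-in consistent with how it is called (variable list, session,
# opt, optional load_path); the real helper presumably also filters variables
# against the checkpoint's keys and shapes, as the explicit loaders above do.
def restore_from_save_sketch(t_vars, sess, opt, load_path=None):
    loader = tf.train.Saver(var_list=t_vars)
    loader.restore(sess, load_path or opt.save_path)  # raises on name/shape mismatch
    return loader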
def main():
    opt = COptions(args)
    opt_t = COptions(args)
    loadpath = opt.data_dir + "/" + opt.data_name
    print "loadpath:" + loadpath
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    if opt.test:
        test_file = opt.data_dir + opt.test_file
        test = read_test(test_file, wordtoix)

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
                    for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True,
                            graph_options=graph_options)
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                if opt.load_from_pretrain:
                    d_vars = [var for var in t_vars if var.name.startswith('d_')]
                    l_vars = [var for var in t_vars if var.name.startswith('l_')]
                    restore_from_save(d_vars, sess, opt, load_path=opt.restore_dir + "/save/" + opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars, sess, opt, load_path=opt.restore_dir + "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars, sess, opt, load_path=opt.save_path)
            except Exception as e:
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())
        loss_d, loss_g = 0, 0

        if opt.test:
            # Encode every test sentence and dump its global / local latent features.
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            z_all, z_all_l = [], []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size, (i + 1) * opt.batch_size)
                sents = [test[t % len(test)] for t in test_index]
                src = [[sents[i][0] for i in range(opt.batch_size)]]
                tgt = [sents[i][0] for i in range(opt.batch_size)]
                x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]
                print "Source:" + u' '.join([ixtoword[x] for s in x_batch for x in s[0] if x != 0]).encode('utf-8').strip()
                y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO=False)
                feed = merge_two_dicts({i: d for i, d in zip(src_, x_batch)},
                                       {tgt_: y_batch, is_train_: 0})
                res_1 = sess.run(res_1_, feed_dict=feed)
                z_all.extend(res_1['z'])
                z_all_l.extend(res_1['z_l'])

            save_path_z = opt.log_path + '.global.z.txt'
            print save_path_z
            if os.path.exists(save_path_z):
                os.remove(save_path_z)
            with open(save_path_z, "a") as myfile:
                for line in z_all[:len(test)]:
                    for z_it in line:
                        myfile.write(str(z_it) + '\t')
                    myfile.write('\n')

            save_path_z = opt.log_path + '.local.z.txt'
            print save_path_z
            if os.path.exists(save_path_z):
                os.remove(save_path_z)
            with open(save_path_z, "a") as myfile:
                for line in z_all_l[:len(test)]:
                    for z_it in line:
                        myfile.write(str(z_it) + '\t')
                    myfile.write('\n')
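# The manual tab-separated dump of latent vectors above can also be written
# with np.savetxt; a sketch producing the same layout (one vector per line,
# tab-delimited), assuming each row of z_rows is a flat float vector. The only
# difference is that the manual loop also emits a trailing tab before each
# newline.
def save_z_sketch(path, z_rows):
    np.savetxt(path, np.asarray(z_rows), fmt='%s', delimiter='\t')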