def run_model(opt, X):
    """Encode every sentence in X into a fixed-size latent vector.

    Builds the encoder graph selected by ``opt.plot_type`` ('ae', or
    'vae'/'cyc' which share the VAE encoder), optionally restores saved
    weights, then runs the corpus through in minibatches.

    Args:
        opt: options object (batch_size, sent_len, z_dim, plot_type, ...).
        X: list of sentences (each a sequence of word indices).

    Returns:
        np.ndarray of shape (len(X), opt.z_dim), dtype float32.
    """
    # Try to reuse a pre-trained word embedding; fall back to training one.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        if opt.plot_type == 'ae':
            x_lat_ = ae(x_, opt)
        elif opt.plot_type == 'vae' or opt.plot_type == 'cyc':
            mu_, z_ = vae(x_, opt)
            # Use the sampled latent z or the posterior mean, per options.
            x_lat_ = z_ if opt.use_z else mu_

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.3
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        X_emb = np.zeros([len(X), opt.z_dim], dtype='float32')
        kf = get_minibatches_idx(len(X), opt.batch_size)
        t = 0
        for _, index in kf:
            # BUGFIX: the placeholder has a fixed batch dimension, so a
            # smaller final minibatch cannot be fed. Pad it by repeating
            # the last sentence, then keep only the real rows below.
            n_real = len(index)
            idx = list(index)
            if n_real < opt.batch_size:
                idx = idx + [idx[-1]] * (opt.batch_size - n_real)
            sents_b = [X[i] for i in idx]
            x_b = prepare_data_for_cnn(sents_b, opt)
            # BUGFIX: np.squeeze collapsed the batch axis when
            # batch_size == 1 or z_dim == 1; reshape explicitly instead.
            # (Assumes the encoder output holds batch_size*z_dim values,
            # possibly with extra singleton axes — as the original squeeze
            # implied.)
            x_lat = sess.run(x_lat_, feed_dict={x_: x_b})
            x_lat = np.asarray(x_lat).reshape(opt.batch_size, opt.z_dim)
            X_emb[t * opt.batch_size:t * opt.batch_size + n_real] = \
                x_lat[:n_real]
            if (t + 1) % 10 == 0:
                print('%d / %d' % (t + 1, len(kf)))
            t += 1
        return X_emb
def run_model(opt, train, val, ixtoword):
    """Train the denoising auto-encoder on `train`, periodically evaluating
    on random minibatches drawn from `val`.

    Args:
        opt: options object (model type, batch/sent sizes, paths, freqs, ...).
        train, val: lists of sentences (sequences of word indices).
        ixtoword: index -> word mapping used only for console previews.
    """
    # Try to reuse a pre-trained word embedding; fall back to training one.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    # Build graph: x_ is the (possibly noised) encoder input, x_org_ the
    # clean reconstruction target.
    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
    # opt.is_train = False
    # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
    # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0  # global minibatch counter, drives valid/print scheduling
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    # config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()  # filled only when `profile` is set

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('Load pretrain successfully')
            except Exception as e:
                # Restore failed: log and fall back to fresh weights.
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                # Corrupt the input (denoising AE); target stays clean.
                sents_permutated = add_noise(sents, opt)
                #sents[0] = np.random.permutation(sents[0])

                # Target layout depends on the decoder side of the model.
                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L
                # Input layout depends on the encoder side.
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt,
                                                   is_add_GO=False)  # Batch L

                if profile:
                    # Profiled step: collect a full execution trace.
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={x_: x_batch, x_org_: x_batch_org,
                                   is_train_: 1},
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={x_: x_batch, x_org_: x_batch_org,
                                   is_train_: 1})

                if uidx % opt.valid_freq == 0:
                    # NOTE(review): is_train is fed as None for eval here —
                    # presumably the graph treats that as "not training";
                    # confirm against auto_encoder's handling.
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(
                            val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(
                        loss_,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                   is_train_: is_train})
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(
                        res_,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                   is_train_: is_train})
                    # Dump reconstructed word indices for offline inspection.
                    np.savetxt(opt.save_txt + '/rec_val_words.txt',
                               res['rec_sents'], fmt='%i', delimiter=' ')
                    try:
                        # Preview one sentence; 0/1 are pad/special tokens.
                        print("Orig:" + u' '.join([
                            ixtoword[x] for x in x_val_batch_org[0]
                            if x != 0 and x != 1
                        ]))  #.encode('utf-8', 'ignore').strip()
                        print("Sent:" + u' '.join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ]))  #.encode('utf-8', 'ignore').strip()
                    except:
                        # Preview is best-effort (e.g. unknown indices).
                        pass
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    summary = sess.run(
                        merged,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                   is_train_: is_train})
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                if uidx % opt.print_freq == 1:
                    #pdb.set_trace()
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={x_: x_batch,
                                              x_org_: x_batch_org,
                                              is_train_: 1})
                    np.savetxt(opt.save_txt + '/rec_train_words.txt',
                               res['rec_sents'], fmt='%i', delimiter=' ')
                    try:
                        print("Orig:" + u' '.join([
                            ixtoword[x] for x in x_batch_org[0]
                            if x != 0 and x != 1
                        ]))  #.encode('utf-8').strip()
                        print("Sent:" + u' '.join([
                            ixtoword[x] for x in res['rec_sents'][0] if x != 0
                        ]))  #.encode('utf-8').strip()
                    except:
                        pass
                    summary = sess.run(merged,
                                       feed_dict={x_: x_batch,
                                                  x_org_: x_batch_org,
                                                  is_train_: 1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

                if profile:
                    # Print the timing/memory profile gathered this step.
                    tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(),
                        run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.
                        PRINT_ALL_TIMING_MEMORY)
            # Checkpoint once per epoch.
            saver.save(sess, opt.save_path, global_step=epoch)
def run_model(opt, train, val, test, train_lab, val_lab, test_lab, wordtoix,
              ixtoword):
    """Train a sentence classifier and track the best validation accuracy.

    Runs one train/val/test sweep per epoch via the nested `run_epoch`,
    checkpointing whenever validation accuracy improves.

    Args:
        opt: options object.
        train/val/test: lists of sentences (word-index sequences).
        train_lab/val_lab/test_lab: matching scalar labels.
        wordtoix, ixtoword: vocabulary maps (unused here beyond signature).
    """
    # Try to reuse a pre-trained word embedding; fall back to training one.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    # Running bests across epochs.
    min_val_loss = 1e30
    min_test_loss = 1e30
    max_val_accuracy = 0.
    max_test_accuracy = 0.
    best_epoch = -1
    epoch_t = len(train) // opt.batch_size  # minibatches per training epoch

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        y_ = tf.placeholder(tf.float32, shape=[opt.batch_size, 1])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        loss_, train_op_, accuracy_ = ae(x_, y_, is_train_, opt, epoch_t)
        merged = tf.summary.merge_all()

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.3
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    def run_epoch(sess, epoch, mode, print_freq=-1, train_writer=None):
        # One pass over the split selected by `mode` ('train'/'val'/'test').
        # Returns (mean loss, mean accuracy), weighted by minibatch size.
        fetches_ = {'loss': loss_, 'accuracy': accuracy_}
        if mode == 'train':
            x, y, is_train = train, train_lab, 1
            fetches_['train_op'] = train_op_
            fetches_['summary'] = merged
        elif mode == 'val':
            assert (print_freq == -1)  # iteration printing is train-only
            x, y, is_train = val, val_lab, None
        elif mode == 'test':
            assert (print_freq == -1)
            x, y, is_train = test, test_lab, None
        correct, acc_loss, acc_n = 0.0, 0.0, 0.0
        local_t = 0
        global_t = epoch * epoch_t  # only used in train mode
        start_time = time.time()
        kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)
        for _, index in kf:
            local_t += 1
            global_t += 1
            sents_b = [x[i] for i in index]
            sents_b_n = add_noise(sents_b, opt)
            y_b = [y[i] for i in index]
            y_b = np.array(y_b)
            y_b = y_b.reshape((len(y_b), 1))  # column vector for y_
            x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch L
            feed_t = {x_: x_b, y_: y_b, is_train_: is_train}
            fetches = sess.run(fetches_, feed_dict=feed_t)
            # Accumulate size-weighted running statistics.
            batch_size = len(index)
            acc_n += batch_size
            acc_loss += fetches['loss'] * batch_size
            correct += fetches['accuracy'] * batch_size
            if print_freq > 0 and local_t % print_freq == 0:
                print("%s Iter %d: loss %.4f, acc %.4f, time %.1fs" %
                      (mode, local_t, acc_loss / acc_n, correct / acc_n,
                       time.time() - start_time))
            if mode == 'train' and train_writer != None:
                train_writer.add_summary(fetches['summary'], global_t)
        print("%s Epoch %d: loss %.4f, acc %.4f, time %.1fs" %
              (mode, epoch, acc_loss / acc_n, correct / acc_n,
               time.time() - start_time))
        return acc_loss / acc_n, correct / acc_n

    with tf.Session(config=config) as sess:
        # writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            _, _ = run_epoch(sess, epoch, 'train')
            val_loss, val_accuracy = run_epoch(sess, epoch, 'val')
            test_loss, test_accuracy = run_epoch(sess, epoch, 'test')
            # if val_loss < min_val_loss:
            #     min_val_loss = val_loss
            #     min_test_loss = test_loss
            #     best_epoch = epoch
            #     max_test_accuracy = test_accuracy
            #     saver.save(sess, opt.save_path+"_cla")
            # Model selection by validation accuracy.
            if val_accuracy > max_val_accuracy:
                max_val_accuracy = val_accuracy
                min_test_loss = test_loss
                best_epoch = epoch
                max_test_accuracy = test_accuracy
                saver.save(sess, opt.save_path + "_cla")
            if opt.save_freq_ep > 0 and (epoch + 1) % opt.save_freq_ep == 0:
                saver.save(sess, opt.save_path + "_cla", global_step=epoch)
            if opt.save_last:
                # Keep the most recent checkpoint under a stable name.
                saver.save(sess, opt.save_path + '_cla_last')
            # print("Min Val Loss %.4f, Min Test Loss %.4f, Max Test Acc %.4f, Best Epoch %d\n" %
            #       (min_val_loss, min_test_loss, max_test_accuracy, best_epoch))
            # Report the running best after each epoch.
            print(
                "Max Val Acc %.4f, Min Test Loss %.4f, Max Test Acc %.4f, Best Epoch %d\n"
                % (max_val_accuracy, min_test_loss, max_test_accuracy,
                   best_epoch))
def run_model(opt, train, val, ixtoword):
    """Train the textGAN generator/discriminator pair (Python 2 script).

    Alternates discriminator and generator updates on `train`, and
    periodically samples/evaluates (BLEU vs. `val`) and writes summaries.

    Args:
        opt: options object (dis_steps/gen_steps, freqs, paths, ...).
        train, val: lists of sentences (word-index sequences).
        ixtoword: index -> word mapping for console previews.
    """
    # Try to reuse a pre-trained word embedding; fall back to training one.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        # NOTE(review): x_org_ and is_train_ are declared but textGAN only
        # consumes x_ here — they appear to be unused leftovers.
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, opt)
        merged = tf.summary.merge_all()
    # opt.is_train = False
    # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
    # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0  # global minibatch counter
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()  # filled only when `profile` is set

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("\nStarting epoch %d\n" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                # Python 2 print statement; trailing comma + \r keeps the
                # iteration counter on one console line.
                print "\rIter: %d" % uidx,
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                #sents[0] = np.random.permutation(sents[0])
                x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch L
                # Skip short final minibatches; the placeholder batch dim
                # is fixed to opt.batch_size.
                if x_batch.shape[0] == opt.batch_size:
                    d_loss = 0
                    g_loss = 0
                    if profile:
                        # Profiled variant of the alternating D/G updates.
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run(
                                [dis_op, d_loss_],
                                feed_dict={x_: x_batch},
                                options=tf.RunOptions(
                                    trace_level=tf.RunOptions.FULL_TRACE),
                                run_metadata=run_metadata)
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run(
                                [gen_op, g_loss_],
                                feed_dict={x_: x_batch},
                                options=tf.RunOptions(
                                    trace_level=tf.RunOptions.FULL_TRACE),
                                run_metadata=run_metadata)
                    else:
                        # Update D every dis_steps batches, G every gen_steps.
                        if uidx % opt.dis_steps == 0:
                            _, d_loss = sess.run([dis_op, d_loss_],
                                                 feed_dict={x_: x_batch})
                        if uidx % opt.gen_steps == 0:
                            _, g_loss = sess.run([gen_op, g_loss_],
                                                 feed_dict={x_: x_batch})

                if uidx % opt.valid_freq == 0:
                    is_train = True
                    # print('Valid Size:', len(val))
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    d_loss_val = sess.run(d_loss_, feed_dict={x_: x_val_batch})
                    g_loss_val = sess.run(g_loss_, feed_dict={x_: x_val_batch})
                    res = sess.run(res_, feed_dict={x_: x_val_batch})
                    print("Validation d_loss %f, g_loss %f mean_dist %f" %
                          (d_loss_val, g_loss_val, res['mean_dist']))
                    # Preview one generated sentence (0 = pad token).
                    print "Sent:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
                    np.savetxt('./text/rec_val_words.txt', res['syn_sent'],
                               fmt='%i', delimiter=' ')
                    if opt.discrimination:
                        print ("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))
                    # BLEU of generated samples against the sampled val batch.
                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['syn_sent']],
                        {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join([str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])
                    summary = sess.run(merged, feed_dict={x_: x_val_batch})
                    test_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    #pdb.set_trace()
                    res = sess.run(res_, feed_dict={x_: x_batch})
                    # Median pairwise distance between real-sample features,
                    # as a diagnostic of feature spread.
                    median_dis = np.sqrt(
                        np.median([((x - y)**2).sum()
                                   for x in res['real_f']
                                   for y in res['real_f']]))
                    print("Iteration %d: d_loss %f, g_loss %f, mean_dist %f, realdist median %f"
                          % (uidx, d_loss, g_loss, res['mean_dist'], median_dis))
                    np.savetxt('./text/rec_train_words.txt', res['syn_sent'],
                               fmt='%i', delimiter=' ')
                    print "Sent:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()
                    summary = sess.run(merged, feed_dict={x_: x_batch})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

                if profile:
                    tf.contrib.tfprof.model_analyzer.print_model_analysis(
                        tf.get_default_graph(),
                        run_meta=run_metadata,
                        tfprof_options=tf.contrib.tfprof.model_analyzer.
                        PRINT_ALL_TIMING_MEMORY)
            # Checkpoint once per epoch.
            saver.save(sess, opt.save_path, global_step=epoch)
def run_model(opt, train, val, ixtoword):
    """Train the conditional textGAN variant that also feeds RNN-formatted
    originals (x_org_), with periodic validation, sampling and BLEU scoring.

    Args:
        opt: options object.
        train, val: lists of sentences (word-index sequences).
        ixtoword: index -> word mapping for console previews.
    """
    # Try to reuse a pre-trained word embedding; fall back to training one.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        # is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()
    # opt.is_train = False
    # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
    # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0  # global minibatch counter
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    # Fixed memory fraction here (unlike the allow_growth used elsewhere).
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()  # filled only when `profile` is set

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        # for i in range(34):
        #     valid_index = np.random.choice(
        #         len(val), opt.batch_size)
        #     val_sents = [val[t] for t in valid_index]
        #     val_sents_permutated = add_noise(val_sents, opt)
        #     x_val_batch = prepare_data_for_cnn(
        #         val_sents_permutated, opt)
        #     x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
        #     res = sess.run(res_, feed_dict={
        #         x_: x_val_batch, x_org_: x_val_batch_org})
        #     if i == 0:
        #         valid_text = res['syn_sent']
        #     else:
        #         valid_text = np.concatenate(
        #             (valid_text, res['syn_sent']), 0)
        # np.savetxt('./text_news/vae_words.txt', valid_text, fmt='%i', delimiter=' ')
        # pdb.set_trace()

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                #sents[0] = np.random.permutation(sents[0])
                # CNN layout for the noised input, RNN layout for originals.
                x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch L
                x_batch_org = prepare_data_for_rnn(sents, opt)
                d_loss = 0
                g_loss = 0
                if profile:
                    # Profiled variant of the alternating D/G updates.
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run(
                            [dis_op, d_loss_],
                            feed_dict={x_: x_batch, x_org_: x_batch_org},
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run(
                            [gen_op, g_loss_],
                            feed_dict={x_: x_batch, x_org_: x_batch_org},
                            options=tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            run_metadata=run_metadata)
                else:
                    # Update D every dis_steps batches, G every gen_steps.
                    if uidx % opt.dis_steps == 0:
                        _, d_loss = sess.run(
                            [dis_op, d_loss_],
                            feed_dict={x_: x_batch, x_org_: x_batch_org})
                    if uidx % opt.gen_steps == 0:
                        _, g_loss = sess.run(
                            [gen_op, g_loss_],
                            feed_dict={x_: x_batch, x_org_: x_batch_org})
                ''' validation '''
                if uidx % opt.valid_freq == 0:
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    d_loss_val = sess.run(
                        d_loss_,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    g_loss_val = sess.run(
                        g_loss_,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    res = sess.run(
                        res_,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    print("Validation d_loss %f, g_loss %f mean_dist %f" %
                          (d_loss_val, g_loss_val, res['mean_dist']))
                    # Preview one generated sentence (0 = pad token).
                    print("Sent:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]))  #.encode('utf-8', 'ignore').decode("utf8").strip())
                    print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
                    # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" %
                              (res['prob_r'], res['prob_f']))
                    # Accumulate 4 extra sampled batches of generations for
                    # the dump file (val_sents/res from the last batch are
                    # reused for BLEU below).
                    for i in range(4):
                        valid_index = np.random.choice(len(val), opt.batch_size)
                        val_sents = [val[t] for t in valid_index]
                        val_sents_permutated = add_noise(val_sents, opt)
                        x_val_batch = prepare_data_for_cnn(
                            val_sents_permutated, opt)
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                        res = sess.run(res_,
                                       feed_dict={x_: x_val_batch,
                                                  x_org_: x_val_batch_org})
                        if i == 0:
                            valid_text = res['syn_sent']
                        else:
                            valid_text = np.concatenate(
                                (valid_text, res['syn_sent']), 0)
                    np.savetxt('./text_news/syn_val_words.txt', valid_text,
                               fmt='%i', delimiter=' ')
                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['syn_sent']],
                        {0: val_set})
                    print('Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))
                    summary = sess.run(
                        merged,
                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
                    test_writer.add_summary(summary, uidx)
def run_model(opt, train, val, test, test_lab, wordtoix, ixtoword):
    """Pretrain the denoising auto-encoder, selecting by validation loss.

    Runs one train/val/test sweep per epoch via the nested `run_epoch`,
    checkpointing whenever validation loss improves.

    Args:
        opt: options object.
        train/val/test: lists of sentences (word-index sequences).
        test_lab, wordtoix: unused here beyond the signature.
        ixtoword: index -> word mapping for reconstruction previews.
    """
    # Try to reuse a pre-trained word embedding; fall back to training one.
    try:
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' +
                  str(params['Wemb'].shape) + ' opt: ' +
                  str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    # Running bests across epochs.
    min_val_loss = 1e50
    min_test_loss = 1e50
    best_epoch = -1
    epoch_t = len(train) // opt.batch_size  # minibatches per training epoch

    with tf.device('/gpu:0'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op_ = ae(x_, x_org_, is_train_, opt, epoch_t)
        merged = tf.summary.merge_all()

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.3
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    def run_epoch(sess, epoch, mode, print_freq=-1, display_sent=-1,
                  train_writer=None):
        # One pass over the split selected by `mode` ('train'/'val'/'test').
        # Returns the size-weighted mean loss; optionally prints
        # `display_sent` reconstruction examples at the end.
        fetches_ = {'loss': loss_}
        if mode == 'train':
            x, is_train = train, 1
            fetches_['train_op'] = train_op_
            fetches_['summary'] = merged
        elif mode == 'val':
            assert (print_freq == -1)  # iteration printing is train-only
            x, is_train = val, None
        elif mode == 'test':
            assert (print_freq == -1)
            x, is_train = test, None
        acc_loss, acc_n = 0.0, 0.0
        local_t = 0
        global_t = epoch * epoch_t  # only used in train mode
        start_time = time.time()
        kf = get_minibatches_idx(len(x), opt.batch_size, shuffle=True)
        for _, index in kf:
            local_t += 1
            global_t += 1
            sents_b = [x[i] for i in index]
            # Noised input, clean RNN-layout target.
            sents_b_n = add_noise(sents_b, opt)
            x_b_org = prepare_data_for_rnn(sents_b, opt)  # Batch L
            x_b = prepare_data_for_cnn(sents_b_n, opt)  # Batch L
            feed_t = {x_: x_b, x_org_: x_b_org, is_train_: is_train}
            fetches = sess.run(fetches_, feed_dict=feed_t)
            batch_size = len(index)
            acc_n += batch_size
            acc_loss += fetches['loss'] * batch_size
            if print_freq > 0 and local_t % print_freq == 0:
                print("%s Iter %d: loss %.4f, time %.1fs" %
                      (mode, local_t, acc_loss / acc_n,
                       time.time() - start_time))
                sys.stdout.flush()
            if mode == 'train' and train_writer != None:
                train_writer.add_summary(fetches['summary'], global_t)
        if display_sent > 0:
            # Show a few original/reconstructed pairs from a random batch;
            # 0 and 2 are presumably pad/EOS-style tokens filtered from view.
            index_d = np.random.choice(len(x), opt.batch_size, replace=False)
            sents_d = [x[i] for i in index_d]
            sents_d_n = add_noise(sents_d, opt)
            x_d_org = prepare_data_for_rnn(sents_d, opt)  # Batch L
            x_d = prepare_data_for_cnn(sents_d_n, opt)  # Batch L
            res = sess.run(res_,
                           feed_dict={x_: x_d, x_org_: x_d_org,
                                      is_train_: is_train})
            for i in range(display_sent):
                print("%s Org: " % mode + " ".join([
                    ixtoword[ix] for ix in sents_d[i]
                    if ix != 0 and ix != 2
                ]))
                if mode == 'train':
                    # Teacher-forced reconstruction is only fetched in train.
                    print("%s Rec(feedy): " % mode + " ".join([
                        ixtoword[ix] for ix in res['rec_sents_feed_y'][i]
                        if ix != 0 and ix != 2
                    ]))
                print("%s Rec: " % mode + " ".join([
                    ixtoword[ix] for ix in res['rec_sents'][i]
                    if ix != 0 and ix != 2
                ]))
        print("%s Epoch %d: loss %.4f, time %.1fs" %
              (mode, epoch, acc_loss / acc_n, time.time() - start_time))
        return acc_loss / acc_n

    with tf.Session(config=config) as sess:
        # writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            _ = run_epoch(sess, epoch, 'train', opt.print_freq, display_sent=1)
            val_loss = run_epoch(sess, epoch, 'val', display_sent=1)
            test_loss = run_epoch(sess, epoch, 'test', display_sent=1)
            # Model selection by validation loss.
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                best_epoch = epoch
                min_test_loss = test_loss
                saver.save(sess, opt.save_path)
            if opt.save_freq_ep > 0 and (epoch + 1) % opt.save_freq_ep == 0:
                saver.save(sess, opt.save_path, global_step=epoch)
            if opt.save_last:
                # Keep the most recent checkpoint under a stable name.
                saver.save(sess, opt.save_path + '_last')
            # Report the running best after each epoch.
            print("Min Val Loss %.4f, Min Test Loss %.4f, Best Epoch %d\n" %
                  (min_val_loss, min_test_loss, best_epoch))
def main(): opt = Options(args) opt_t = Options(args) opt_t.n_hid = opt.n_z loadpath = (opt.data_dir + "/" + opt.data_name) print "loadpath:" + loadpath x = cPickle.load(open(loadpath, "rb")) train, val, test = x[0], x[1], x[2] wordtoix, ixtoword = x[3], x[4] opt.n_words = len(ixtoword) print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y") print dict(opt) print('Total words: %d' % opt.n_words) opt.n_words = len(ixtoword) opt_t.n_words = len(ixtoword) print dict(opt) if opt.model == 'cnn_rnn': opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1 opt_t.update_params(args) print dict(opt_t) print('Total words: %d' % opt.n_words) for d in ['/gpu:0']: with tf.device(d): src_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len]) z_ = tf.placeholder(tf.float32, shape=[opt.batch_size, opt.n_z]) res_, gan_cost_g_, train_op_g = conditional_s2s(src_, tgt_, z_, opt, opt_t) merged = tf.summary.merge_all() uidx = 0 graph_options=tf.GraphOptions(build_cost_model=1) config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=graph_options) config.gpu_options.per_process_gpu_memory_fraction = 0.90 np.set_printoptions(precision=3) np.set_printoptions(threshold=np.inf) saver = tf.train.Saver() run_metadata = tf.RunMetadata() with tf.Session(config = config) as sess: train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph) test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph) sess.run(tf.global_variables_initializer()) if opt.restore: try: t_vars = tf.trainable_variables() if opt.load_from_pretrain: d_vars = [var for var in t_vars if var.name.startswith('d_')] g_vars = [var for var in t_vars if var.name.startswith('g_')] restore_from_save(g_vars, sess, opt, opt.restore_dir + "/save/generator") restore_from_save(d_vars, sess, opt, opt.restore_dir + "/save/discriminator") else: loader = restore_from_save(t_vars, sess, opt) 
except Exception as e: print 'Error: '+str(e) print("No saving session, using random initialization") sess.run(tf.global_variables_initializer()) loss_d , loss_g = 0, 0 for epoch in range(opt.max_epochs): print("Starting epoch %d" % epoch) kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True) for _, train_index in kf: uidx += 1 indice = [(x,x+1) for x in range(7)] sents = [train[t] for t in train_index] for idx in indice: src = [sents[i][idx[0]] for i in range(opt.batch_size)] tgt = [sents[i][idx[1]] for i in range(opt.batch_size)] src_permutated = src x_batch = prepare_data_for_cnn(src_permutated, opt) y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False) if opt.z_prior == 'g': z_batch = np.random.normal(0,1,(opt.n_hid, opt.n_z)).astype('float32') else: z_batch = np.random.uniform(-1,1,(opt.batch_size, opt.n_z)).astype('float32') feed = {src_: x_batch, tgt_: y_batch, z_:z_batch} _, loss_g = sess.run([train_op_g, gan_cost_g_],feed_dict=feed) if uidx%opt.print_freq == 0: print("Iteration %d: loss G %f" %(uidx, loss_g)) res = sess.run(res_, feed_dict=feed) print "Source:" + u' '.join([ixtoword[x] for x in x_batch[0] if x != 0]).encode('utf-8').strip() print "Target:" + u' '.join([ixtoword[x] for x in y_batch[0] if x != 0]).encode('utf-8').strip() print "Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip() print "" sys.stdout.flush() summary = sess.run(merged, feed_dict=feed) train_writer.add_summary(summary, uidx) if uidx%opt.valid_freq == 1: VALID_SIZE = 4096 valid_multiplier = np.int(np.floor(VALID_SIZE/opt.batch_size)) res_all, val_tgt_all, loss_val_d_all, loss_val_g_all = [], [], [], [] if opt.global_feature: z_loss_all = [] for val_step in range(valid_multiplier): valid_index = np.random.choice(len(test), opt.batch_size) indice = [(x,x+1) for x in range(7)] val_sents = [test[t] for t in valid_index] for idx in indice: val_src = [val_sents[i][idx[0]] for i in range(opt.batch_size)] val_tgt = 
[val_sents[i][idx[1]] for i in range(opt.batch_size)] val_tgt_all.extend(val_tgt) val_src_permutated = val_src x_val_batch = prepare_data_for_cnn(val_src, opt) y_val_batch = prepare_data_for_rnn(val_src, opt_t, is_add_GO = False) if opt.model == 'cnn_rnn' else prepare_data_for_cnn(val_src, opt_t) if opt.z_prior == 'g': z_val_batch = np.random.normal(0,1,(opt.batch_size, opt.n_z)).astype('float32') else: z_val_batch = np.random.uniform(-1,1,(opt.batch_size, opt.n_z)).astype('float32') feed_val = {src_: x_val_batch, tgt_: y_val_batch, z_:z_val_batch} loss_val_g = sess.run([gan_cost_g_], feed_dict=feed_val) loss_val_g_all.append(loss_val_g) res = sess.run(res_, feed_dict=feed_val) res_all.extend(res['syn_sent']) if opt.global_feature: z_loss_all.append(res['z_loss']) print("Validation: loss G %f " %( np.mean(loss_val_g_all))) print "Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip() print "Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip() print "Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip() print "" if opt.global_feature: with open(opt.log_path + '.z.txt', "a") as myfile: myfile.write("Iteration" + str(uidx) + "\n") myfile.write("z_loss %f" %(np.mean(z_loss_all))+ "\n") myfile.write("Val Source:" + u' '.join([ixtoword[x] for x in val_src[0] if x != 0]).encode('utf-8').strip()+ "\n") myfile.write("Val Target :" + u' '.join([ixtoword[x] for x in val_tgt[0] if x != 0]).encode('utf-8').strip()+ "\n") myfile.write("Val Generated:" + u' '.join([ixtoword[x] for x in res['syn_sent'][0] if x != 0]).encode('utf-8').strip()+ "\n") myfile.write(np.array2string(res['z'][0], formatter={'float_kind':lambda x: "%.2f" % x})+ "\n") myfile.write(np.array2string(res['z_hat'][0], formatter={'float_kind':lambda x: "%.2f" % x})+ "\n\n") val_set = [prepare_for_bleu(s) for s in val_tgt_all] gen = [prepare_for_bleu(s) for s in res_all] 
[bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus) etp_score, dist_score = cal_entropy(gen) print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)]) print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])]) print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])]) print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3)) print "" summary = sess.run(merged, feed_dict=feed_val) summary2 = tf.Summary(value=[tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),tf.Summary.Value(tag="etp-4", simple_value=etp_score[3])]) test_writer.add_summary(summary, uidx) test_writer.add_summary(summary2, uidx) if uidx%opt.save_freq == 0: saver.save(sess, opt.save_path)
def main(): # global n_words # Prepare training and testing data # loadpath = "./data/three_corpus_small.p" loadpath = "./data/hotel_reviews.p" x = cPickle.load(open(loadpath, "rb")) train, val = x[0], x[1] wordtoix, ixtoword = x[2], x[3] opt = Options() opt.n_words = len(ixtoword) + 1 ixtoword[opt.n_words - 1] = 'GO_' print dict(opt) print('Total words: %d' % opt.n_words) try: params = np.load('./param_g.npz') if params['Wemb'].shape == (opt.n_words, opt.embed_size): print('Use saved embedding.') opt.W_emb = params['Wemb'] else: print('Emb Dimension mismatch: param_g.npz:' + str(params['Wemb'].shape) + ' opt: ' + str( (opt.n_words, opt.embed_size))) opt.fix_emb = False except IOError: print('No embedding file found.') opt.fix_emb = False with tf.device('/gpu:1'): x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) res_, loss_, train_op = auto_encoder(x_, x_org_, opt) merged = tf.summary.merge_all() # opt.is_train = False # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt) # merged_val = tf.summary.merge_all() # tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006 # writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph()) uidx = 0 config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) # config = tf.ConfigProto(device_count={'GPU':0}) config.gpu_options.allow_growth = True np.set_printoptions(precision=3) np.set_printoptions(threshold=np.inf) saver = tf.train.Saver() with tf.Session(config=config) as sess: train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph) test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph) sess.run(tf.global_variables_initializer()) if opt.restore: try: # pdb.set_trace() t_vars = tf.trainable_variables() # print([var.name[:-2] for var in t_vars]) loader = restore_from_save(t_vars, sess, opt) except Exception as e: print(e) print("No saving session, using random 
initialization") sess.run(tf.global_variables_initializer()) for epoch in range(opt.max_epochs): print("Starting epoch %d" % epoch) # if epoch >= 10: # print("Relax embedding ") # opt.fix_emb = False # opt.batch_size = 2 kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True) for _, train_index in kf: uidx += 1 sents = [train[t] for t in train_index] if opt.substitution == 's': sents_permutated = substitute_sent(sents, opt) elif opt.substitution == 'p': sents_permutated = permutate_sent(sents, opt) elif opt.substitution == 'a': sents_permutated = add_sent(sents, opt) elif opt.substitution == 'd': sents_permutated = delete_sent(sents, opt) else: sents_permutated = sents # sents[0] = np.random.permutation(sents[0]) if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn': x_batch_org = prepare_data_for_cnn(sents, opt) # Batch L else: x_batch_org = prepare_data_for_rnn(sents, opt) # Batch L if opt.model != 'rnn_rnn': x_batch = prepare_data_for_cnn(sents_permutated, opt) # Batch L else: x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO=False) # Batch L # x_print = sess.run([x_emb],feed_dict={x_: x_train} ) # print x_print # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org}) # pdb.set_trace() _, loss = sess.run([train_op, loss_], feed_dict={ x_: x_batch, x_org_: x_batch_org }) if uidx % opt.valid_freq == 0: opt.is_train = False valid_index = np.random.choice(len(val), opt.batch_size) val_sents = [val[t] for t in valid_index] if opt.substitution == 's': val_sents_permutated = substitute_sent(val_sents, opt) elif opt.substitution == 'p': val_sents_permutated = permutate_sent(val_sents, opt) elif opt.substitution == 'a': val_sents_permutated = add_sent(val_sents, opt) elif opt.substitution == 'd': val_sents_permutated = delete_sent(val_sents, opt) else: val_sents_permutated = sents if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn': x_val_batch_org = prepare_data_for_cnn(val_sents, opt) else: x_val_batch_org = 
prepare_data_for_rnn(val_sents, opt) if opt.model != 'rnn_rnn': x_val_batch = prepare_data_for_cnn( val_sents_permutated, opt) else: x_val_batch = prepare_data_for_rnn( val_sents_permutated, opt, is_add_GO=False) loss_val = sess.run(loss_, feed_dict={ x_: x_val_batch, x_org_: x_val_batch_org }) print("Validation loss %f " % (loss_val)) res = sess.run(res_, feed_dict={ x_: x_val_batch, x_org_: x_val_batch_org }) if opt.discrimination: print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f'])) print "Val Orig :" + " ".join( [ixtoword[x] for x in val_sents[0] if x != 0]) print "Val Perm :" + " ".join([ ixtoword[x] for x in val_sents_permutated[0] if x != 0 ]) print "Val Recon:" + " ".join( [ixtoword[x] for x in res['rec_sents'][0] if x != 0]) val_set = [prepare_for_bleu(s) for s in val_sents] [bleu2s, bleu3s, bleu4s] = cal_BLEU( [prepare_for_bleu(s) for s in res['rec_sents']], {0: val_set}) print 'Val BLEU (2,3,4): ' + ' '.join( [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]) summary = sess.run(merged, feed_dict={ x_: x_val_batch, x_org_: x_val_batch_org }) test_writer.add_summary(summary, uidx) opt.is_train = True if uidx % opt.print_freq == 0: print("Iteration %d: loss %f " % (uidx, loss)) res = sess.run(res_, feed_dict={ x_: x_batch, x_org_: x_batch_org }) print "Original :" + " ".join( [ixtoword[x] for x in sents[0] if x != 0]) print "Permutated :" + " ".join( [ixtoword[x] for x in sents_permutated[0] if x != 0]) if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn': print "Reconstructed:" + " ".join([ ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0 ]) print "Reconstructed:" + " ".join( [ixtoword[x] for x in res['rec_sents'][0] if x != 0]) summary = sess.run(merged, feed_dict={ x_: x_batch, x_org_: x_batch_org }) train_writer.add_summary(summary, uidx) # print res['x_rec'][0][0] # print res['x_emb'][0][0] saver.save(sess, opt.save_path, global_step=epoch)
def run_model(opt, train, val, test, wordtoix, ixtoword):
    """Train the denoising auto-encoder (word- or char-level) and log to TensorBoard.

    Builds the graph on /gpu:1 with an `is_train_` bool placeholder, trains on
    noised minibatches, and every `opt.valid_freq` steps reports validation
    loss, sample reconstructions, BLEU and CER. Checkpoints once per epoch.

    NOTE(review): relies on a module-level `profile` flag (defined outside this
    view) to enable full-trace profiling — confirm it is defined before use.
    """
    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        summary_ext = tf.Summary()

    uidx = 0  # global step counter across epochs
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                # Fall back to a fresh random initialization on restore failure.
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                # Clean side (x_org_) is the target; noised side (x_) is the input.
                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO=False)  # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print
                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()
                if profile:
                    # Collect a full execution trace for the tfprof dump below.
                    _, loss = sess.run(
                        [train_op, loss_],
                        feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1},
                        options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
                # pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    # NOTE(review): is_train_ is fed None here (not False) —
                    # looks intentional in the original; confirm the graph
                    # tolerates it before changing.
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                                   is_train_: is_train})
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                              is_train_: is_train})
                    if opt.discrimination:
                        print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))
                    # Char-level models join without spaces; word-level with spaces.
                    # Index 0 is padding and is stripped before printing.
                    if opt.char:
                        print "Val Orig :" + "".join([ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + "".join([ixtoword[x] for x in val_sents_permutated[0] if x != 0])
                        print "Val Recon:" + "".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])
                        # print "Val Recon one hot:" + "".join([ixtoword[x] for x in res['rec_sents_one_hot'][0] if x != 0])
                    else:
                        print "Val Orig :" + " ".join([ixtoword[x] for x in val_sents[0] if x != 0])
                        print "Val Perm :" + " ".join([ixtoword[x] for x in val_sents_permutated[0] if x != 0])
                        print "Val Recon:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']], {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])

                    # Character error rate against the clean validation batch.
                    val_set_char = [prepare_for_cer(s, ixtoword) for s in val_sents]
                    cer = cal_cer([prepare_for_cer(s, ixtoword) for s in res['rec_sents']],
                                  val_set_char)
                    print 'Val CER: ' + str(round(cer, 3))
                    # summary_ext.Value(tag='CER', simple_value=cer)
                    # CER is computed outside the graph, so it is logged via a
                    # hand-built Summary proto rather than a summary op.
                    summary_ext = tf.Summary(
                        value=[tf.Summary.Value(tag='CER', simple_value=cer)])
                    # tf.summary.scalar('CER', cer)
                    # if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #     print "Gen Probs:" + " ".join([str(np.round(res['gen_p'][i], 1)) for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    summary = sess.run(merged,
                                       feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                                  is_train_: is_train})
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary_ext, uidx)
                    is_train = True

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
                    # if 1 in res['rec_sents'][0] or 1 in sents[0]:
                    #     pdb.set_trace()
                    if opt.char:
                        print "Original :" + "".join([ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated :" + "".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + " ".join(
                                [ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
                        print "Reconstructed:" + "".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])
                    else:
                        print "Original :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
                        print "Permutated :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                        if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                            print "Reconstructed:" + " ".join(
                                [ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
                        print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])
                    summary = sess.run(merged,
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            if profile:
                # Dump per-op timing/memory stats gathered via run_metadata.
                tf.contrib.tfprof.model_analyzer.print_model_analysis(
                    tf.get_default_graph(),
                    run_meta=run_metadata,
                    tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
            saver.save(sess, opt.save_path)
def run_model(opt, train_unlab_x, train_lab_x, train_lab, val_unlab_x, val_lab_x,
              val_lab, test, test_y, wordtoix, ixtoword):
    """Train a semi-supervised classifier: reconstruction on unlabeled text plus
    a discriminative loss on labeled text, mixed by a decaying `opt.rec_alpha`.

    Tracks the best validation accuracy and evaluates on the test set whenever
    validation improves. Checkpoints once per epoch.
    """
    try:
        # Warm-start the embedding only if its shape matches the vocabulary.
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' + str(params['Wemb'].shape) + ' opt: ' + str(
                (opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        alpha_ = tf.placeholder(tf.float32, shape=())  # reconstruction-loss weight
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_lab_ = tf.placeholder(tf.int32, shape=[opt.dis_batch_size, opt.sent_len])
        y_ = tf.placeholder(tf.float32, shape=[opt.dis_batch_size, 1])  # binary labels
        dp_ratio_ = tf.placeholder(tf.float32, name='dp_ratio_')  # dropout keep ratio
        res_, dis_loss_, rec_loss_, loss_, train_op, prob_, acc_ = semi_classifier(
            alpha_, x_, x_org_, x_lab_, y_, dp_ratio_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    max_val_accuracy = 0.0
    max_test_accuracy = 0.0
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    # config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train_unlab_x), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                # After pretraining, anneal the reconstruction weight toward 0.
                if opt.rec_alpha > 0 and uidx > opt.pretrain_step and uidx % opt.rec_decay_freq == 0:
                    opt.rec_alpha -= 0.01
                    print "alpha: " + str(opt.rec_alpha)
                sents = [train_unlab_x[t] for t in train_index]
                # Sample a labeled minibatch independently of the unlabeled one.
                lab_index = np.random.choice(len(train_lab), opt.dis_batch_size, replace=False)
                lab_sents = [train_lab_x[t] for t in lab_index]
                batch_lab = [train_lab[t] for t in lab_index]
                batch_lab = np.array(batch_lab)
                batch_lab = batch_lab.reshape((len(batch_lab), 1))
                x_batch_lab = prepare_data_for_cnn(lab_sents, opt)
                sents_permutated = add_noise(sents, opt)
                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO=False)  # Batch L

                _, dis_loss, rec_loss, loss, acc = sess.run(
                    [train_op, dis_loss_, rec_loss_, loss_, acc_],
                    feed_dict={
                        alpha_: opt.rec_alpha,
                        x_: x_batch,
                        x_org_: x_batch_org,
                        x_lab_: x_batch_lab,
                        y_: batch_lab,
                        dp_ratio_: opt.dropout_ratio
                    })
                # NOTE(review): summaries are recomputed in a second sess.run —
                # an extra forward pass per step kept from the original.
                summary = sess.run(merged,
                                   feed_dict={
                                       alpha_: opt.rec_alpha,
                                       x_: x_batch,
                                       x_org_: x_batch_org,
                                       x_lab_: x_batch_lab,
                                       y_: batch_lab,
                                       dp_ratio_: opt.dropout_ratio
                                   })
                train_writer.add_summary(summary, uidx)

                if uidx % opt.print_freq == 0:
                    print("Iteration %d: dis_loss %f, rec_loss %f, loss %f, acc %f " %
                          (uidx, dis_loss, rec_loss, loss, acc))

                if uidx % opt.valid_freq == 0:
                    # print("Iteration %d: dis_loss %f, rec_loss %f, loss %f " % (uidx, dis_loss, rec_loss, loss))
                    valid_index = np.random.choice(len(val_unlab_x), opt.batch_size)
                    val_sents = [val_unlab_x[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(val_sents_permutated, opt, is_add_GO=False)
                    # Dropout disabled (ratio 1.0) for all evaluation runs.
                    rec_loss_val = sess.run(rec_loss_,
                                            feed_dict={
                                                x_: x_val_batch,
                                                x_org_: x_val_batch_org,
                                                dp_ratio_: 1.0
                                            })
                    print("Validation rec loss %f " % rec_loss_val)

                    # Score the whole labeled validation set in dis_batch_size chunks.
                    kf_val = get_minibatches_idx(len(val_lab_x), opt.dis_batch_size, shuffle=False)
                    prob_val = []
                    for _, val_ind in kf_val:
                        val_sents = [val_lab_x[t] for t in val_ind]
                        x_val_dis = prepare_data_for_cnn(val_sents, opt)
                        val_y = np.array([val_lab[t] for t in val_ind]).reshape(
                            (opt.dis_batch_size, 1))
                        val_prob = sess.run(prob_,
                                            feed_dict={x_lab_: x_val_dis, dp_ratio_: 1.0})
                        for x in val_prob:
                            prob_val.append(x)
                    ##### DON'T UNDERSTAND :error val_index
                    # probs = []
                    # val_truth = []
                    # for i in range(len(val_lab)):
                    #     val_truth.append(val_lab[i])
                    #     if type(val_index[i]) != int:
                    #         temp = []
                    #         for j in val_index[i]:
                    #             temp.append(prob_val[j])
                    #         aver = sum(temp) * 1.0 / len(temp)
                    #         probs.append(aver)
                    #     else:
                    #         probs.append(prob_val[val_index[i]])
                    # NOTE(review): assumes kf_val visits val_lab in order, so
                    # prob_val[i] aligns with val_lab[i] — verify.
                    probs = []
                    val_truth = []
                    for i in range(len(prob_val)):
                        val_truth.append(val_lab[i])
                        probs.append(prob_val[i])
                    # Accuracy at a fixed 0.5 decision threshold.
                    count = 0.0
                    for i in range(len(probs)):
                        p = probs[i]
                        if p > 0.5:
                            if val_truth[i] == 1:
                                count += 1.0
                        else:
                            if val_truth[i] == 0:
                                count += 1.0
                    val_accuracy = count * 1.0 / len(probs)
                    print("Validation accuracy %f " % val_accuracy)
                    summary = sess.run(merged,
                                       feed_dict={
                                           alpha_: opt.rec_alpha,
                                           x_: x_val_batch,
                                           x_org_: x_val_batch_org,
                                           x_lab_: x_val_dis,
                                           y_: val_y,
                                           dp_ratio_: 1.0
                                       })
                    test_writer.add_summary(summary, uidx)

                    # Evaluate on test only when validation accuracy improves
                    # (early-stopping-style model selection).
                    if val_accuracy >= max_val_accuracy:
                        max_val_accuracy = val_accuracy
                        kf_test = get_minibatches_idx(len(test), opt.dis_batch_size, shuffle=False)
                        prob_test = []
                        for _, test_ind in kf_test:
                            test_sents = [test[t] for t in test_ind]
                            x_test_batch = prepare_data_for_cnn(test_sents, opt)
                            test_prob = sess.run(prob_,
                                                 feed_dict={x_lab_: x_test_batch, dp_ratio_: 1.0})
                            for x in test_prob:
                                prob_test.append(x)
                        probs = []
                        test_truth = []
                        for i in range(len(prob_test)):
                            test_truth.append(test_y[i])
                            probs.append(prob_test[i])
                        # probs = []
                        # test_truth = []
                        # for i in range(len(test_y)):
                        #     test_truth.append(test_y[i])
                        #     if type(test_index[i]) != int:
                        #         temp = [prob_test[j] for j in test_index[i]]
                        #         aver = sum(temp) * 1.0 / len(temp)
                        #         probs.append(aver)
                        #     else:
                        #         probs.append(prob_test[test_index[i]])
                        count = 0.0
                        for i in range(len(probs)):
                            p = probs[i]
                            if p > 0.5:
                                if test_truth[i] == 1.0:
                                    count += 1.0
                            else:
                                if test_truth[i] == 0.0:
                                    count += 1.0
                        test_accuracy = count * 1.0 / len(probs)
                        print("Test accuracy %f " % test_accuracy)
                        max_test_accuracy = test_accuracy

            def test_input(text):
                # Debug helper (manual use): reconstruct a raw input sentence.
                # NOTE(review): closes over x_batch_org from the last minibatch.
                x_input = sent2idx(text, wordtoix, opt)
                res = sess.run(res_, feed_dict={x_: x_input, x_org_: x_batch_org})
                print "Reconstructed:" + " ".join(
                    [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

            # res = sess.run(res_, feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_: 1})
            # print "Original :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
            # # print "Permutated :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
            # if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
            #     print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
            # print "Reconstructed:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])
            # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] +'(' +str(np.round(res['all_p'][i],2))+')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
            print(opt.rec_alpha)
            print("Epoch %d: Max Valid accuracy %f" % (epoch, max_val_accuracy))
            print("Epoch %d: Max Test accuracy %f" % (epoch, max_test_accuracy))
            saver.save(sess, opt.save_path, global_step=epoch)
def run_model(opt, train, val, ixtoword):
    """Load a trained textGAN and generate synthetic sentences to a text file.

    Despite the name, this variant does no training: it restores a checkpoint,
    prints one validation diagnostic (d/g losses, a sample sentence, MMD/GAN
    terms), then generates 268590 sentences batch by batch, saves them via
    np.savetxt, and calls exit(). The BLEU code after exit() is unreachable.

    NOTE(review): PATH_TO_SAVE is a module-level constant defined outside this
    view — confirm it is set before running.
    """
    try:
        # Warm-start the embedding only if its shape matches the vocabulary.
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:' + str(params['Wemb'].shape) +
                  ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        res_, g_loss_, d_loss_, gen_op, dis_op = textGAN(x_, x_org_, opt)
        merged = tf.summary.merge_all()

    uidx = 0
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    # keep all checkpoints
    saver = tf.train.Saver(max_to_keep=None)
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
                print('\nload successfully\n')
            except Exception as e:
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        ''' validation '''
        valid_index = np.random.choice(len(val), opt.batch_size)
        val_sents = [val[t] for t in valid_index]
        val_sents_permutated = add_noise(val_sents, opt)
        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
        d_loss_val = sess.run(d_loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
        g_loss_val = sess.run(g_loss_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
        res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
        try:
            print("Validation d_loss %f, g_loss %f mean_dist %f" %
                  (d_loss_val, g_loss_val, res['mean_dist']))
            print("Sent:" + u' '.join([
                ixtoword[x] for x in res['syn_sent'][0] if x != 0
            ]))  #.encode('utf-8', 'ignore').decode("utf8").strip())
            print("MMD loss %f, GAN loss %f" % (res['mmd'], res['gan']))
        except Exception as e:
            # Diagnostics are best-effort; generation below still proceeds.
            print(e)
        # np.savetxt('./text_arxiv/syn_val_words.txt', res['syn_sent'], fmt='%i', delimiter=' ')
        if opt.discrimination:
            print("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))

        # Generate enough batches to cover 268590 sentences, then truncate.
        for i in range(268590 // 1000 + 1):
            # generate 10k sentences
            # generate 268590
            valid_index = np.random.choice(len(val), opt.batch_size)
            val_sents = [val[t] for t in valid_index]
            val_sents_permutated = add_noise(val_sents, opt)
            x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
            x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
            res = sess.run(res_, feed_dict={x_: x_val_batch, x_org_: x_val_batch_org})
            if i == 0:
                valid_text = res['syn_sent']
            else:
                valid_text = np.concatenate((valid_text, res['syn_sent']), 0)
        valid_text = valid_text[:268590]
        # Word indices, one sentence per row.
        np.savetxt(PATH_TO_SAVE, valid_text, fmt='%i', delimiter=' ')
        print('saved!\n\n\n')
        exit()

        # NOTE(review): unreachable — kept from the original for reference.
        val_set = [prepare_for_bleu(s) for s in val]  #val_sents]
        bleu_prepared = [prepare_for_bleu(s) for s in res['syn_sent']]
        for i in range(len(val_set) // opt.batch_size):
            batch = val_set[i * opt.batch_size:(i + 1) * opt.batch_size]
            [bleu2s, bleu3s, bleu4s] = cal_BLEU(bleu_prepared, {0: batch})  #val_set})
            print('Val BLEU (2,3,4): ' + ' '.join(
                [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)]))
def run_model(opt, train, val, test, wordtoix, ixtoword):
    """Train the word-level denoising auto-encoder with TensorBoard logging.

    Loads an optional pretrained embedding, builds the graph on /gpu:1 with an
    `is_train_` placeholder, trains on noised batches, validates with loss and
    BLEU every `opt.valid_freq` steps, and checkpoints once per epoch.

    NOTE(review): relies on a module-level `profile` flag (defined outside this
    view) to enable full-trace profiling — confirm it is defined before use.
    """
    try:
        # Warm-start the embedding only if its shape matches the vocabulary.
        params = np.load('./param_g.npz')
        if params['Wemb'].shape == (opt.n_words, opt.embed_size):
            print('Use saved embedding.')
            opt.W_emb = params['Wemb']
        else:
            print('Emb Dimension mismatch: param_g.npz:'+ str(params['Wemb'].shape) +
                  ' opt: ' + str((opt.n_words, opt.embed_size)))
            opt.fix_emb = False
    except IOError:
        print('No embedding file found.')
        opt.fix_emb = False

    with tf.device('/gpu:1'):
        x_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        x_org_ = tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len])
        is_train_ = tf.placeholder(tf.bool, name='is_train_')
        res_, loss_, train_op = auto_encoder(x_, x_org_, is_train_, opt)
        merged = tf.summary.merge_all()
        # opt.is_train = False
        # res_val_, loss_val_, _ = auto_encoder(x_, x_org_, opt)
        # merged_val = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())

    uidx = 0  # global step counter across epochs
    config = tf.ConfigProto(log_device_placement = False,
                            allow_soft_placement=True,
                            graph_options=tf.GraphOptions(build_cost_model=1))
    #config = tf.ConfigProto(device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()
                #print([var.name[:-2] for var in t_vars])
                loader = restore_from_save(t_vars, sess, opt)
            except Exception as e:
                # Fall back to a fresh random initialization on restore failure.
                print(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            # if epoch >= 10:
            #     print("Relax embedding ")
            #     opt.fix_emb = False
            #     opt.batch_size = 2
            kf = get_minibatches_idx(len(train), opt.batch_size, shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                sents_permutated = add_noise(sents, opt)
                #sents[0] = np.random.permutation(sents[0])

                # Clean side (x_org_) is the target; noised side (x_) is the input.
                if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    x_batch_org = prepare_data_for_cnn(sents, opt)  # Batch L
                else:
                    x_batch_org = prepare_data_for_rnn(sents, opt)  # Batch L
                if opt.model != 'rnn_rnn':
                    x_batch = prepare_data_for_cnn(sents_permutated, opt)  # Batch L
                else:
                    x_batch = prepare_data_for_rnn(sents_permutated, opt, is_add_GO = False)  # Batch L
                # x_print = sess.run([x_emb],feed_dict={x_: x_train} )
                # print x_print
                # res = sess.run(res_, feed_dict={x_: x_batch, x_org_:x_batch_org})
                # pdb.set_trace()
                if profile:
                    # Collect a full execution trace for the tfprof dump below.
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1},
                                       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                       run_metadata=run_metadata)
                else:
                    _, loss = sess.run([train_op, loss_],
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                #pdb.set_trace()

                if uidx % opt.valid_freq == 0:
                    # NOTE(review): is_train_ is fed None here (not False) —
                    # kept from the original; confirm the graph tolerates it.
                    is_train = None
                    valid_index = np.random.choice(len(val), opt.batch_size)
                    val_sents = [val[t] for t in valid_index]
                    val_sents_permutated = add_noise(val_sents, opt)
                    if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                        x_val_batch_org = prepare_data_for_cnn(val_sents, opt)
                    else:
                        x_val_batch_org = prepare_data_for_rnn(val_sents, opt)
                    if opt.model != 'rnn_rnn':
                        x_val_batch = prepare_data_for_cnn(val_sents_permutated, opt)
                    else:
                        x_val_batch = prepare_data_for_rnn(val_sents_permutated, opt, is_add_GO=False)

                    loss_val = sess.run(loss_,
                                        feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                                   is_train_:is_train })
                    print("Validation loss %f " % (loss_val))
                    res = sess.run(res_,
                                   feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                              is_train_:is_train })
                    if opt.discrimination:
                        print ("Real Prob %f Fake Prob %f" % (res['prob_r'], res['prob_f']))
                    # Index 0 is padding and is stripped before printing.
                    print "Val Orig :" + " ".join([ixtoword[x] for x in val_sents[0] if x != 0])
                    #print "Val Perm :" + " ".join([ixtoword[x] for x in val_sents_permutated[0] if x != 0])
                    print "Val Recon:" + " ".join([ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                    val_set = [prepare_for_bleu(s) for s in val_sents]
                    [bleu2s, bleu3s, bleu4s] = cal_BLEU(
                        [prepare_for_bleu(s) for s in res['rec_sents']], {0: val_set})
                    print 'Val BLEU (2,3,4): ' + ' '.join(
                        [str(round(it, 3)) for it in (bleu2s, bleu3s, bleu4s)])
                    # if opt.model != 'rnn_rnn' and opt.model != 'cnn_rnn':
                    #     print "Org Probs:" + " ".join(
                    #         [ixtoword[x_val_batch_org[0][i]] + '(' + str(np.round(res['all_p'][i], 1)) + ')' for i in
                    #          range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    #     print "Gen Probs:" + " ".join(
                    #         [ixtoword[res['rec_sents'][0][i]] + '(' + str(np.round(res['gen_p'][i], 1)) + ')' for i in
                    #          range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    summary = sess.run(merged,
                                       feed_dict={x_: x_val_batch, x_org_: x_val_batch_org,
                                                  is_train_:is_train })
                    test_writer.add_summary(summary, uidx)
                    is_train = True

                def test_input(text):
                    # Debug helper (manual use): reconstruct a raw input sentence.
                    # NOTE(review): closes over x_batch_org from the current batch.
                    x_input = sent2idx(text, wordtoix, opt)
                    res = sess.run(res_, feed_dict={x_: x_input, x_org_: x_batch_org})
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])

                if uidx % opt.print_freq == 0:
                    #pdb.set_trace()
                    print("Iteration %d: loss %f " % (uidx, loss))
                    res = sess.run(res_,
                                   feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                    print "Original :" + " ".join([ixtoword[x] for x in sents[0] if x != 0])
                    #print "Permutated :" + " ".join([ixtoword[x] for x in sents_permutated[0] if x != 0])
                    if opt.model == 'rnn_rnn' or opt.model == 'cnn_rnn':
                        # Seq2seq variants also expose teacher-forced output.
                        print "Reconstructed:" + " ".join(
                            [ixtoword[x] for x in res['rec_sents_feed_y'][0] if x != 0])
                    print "Reconstructed:" + " ".join(
                        [ixtoword[x] for x in res['rec_sents'][0] if x != 0])
                    # print "Probs:" + " ".join([ixtoword[res['rec_sents'][0][i]] +'(' +str(np.round(res['all_p'][i],2))+')' for i in range(len(res['rec_sents'][0])) if res['rec_sents'][0][i] != 0])
                    summary = sess.run(merged,
                                       feed_dict={x_: x_batch, x_org_: x_batch_org, is_train_:1})
                    train_writer.add_summary(summary, uidx)
                    # print res['x_rec'][0][0]
                    # print res['x_emb'][0][0]

            if profile:
                # Dump per-op timing/memory stats gathered via run_metadata.
                tf.contrib.tfprof.model_analyzer.print_model_analysis(
                    tf.get_default_graph(),
                    run_meta=run_metadata,
                    tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
            saver.save(sess, opt.save_path, global_step=epoch)
def main():
    """Test-time driver: load the dialogue pickle, build the feature
    extractor (get_features) and response generator (generate_resp),
    decode every held-out turn, report BLEU/entropy/diversity, and —
    when a global feature bit is selected — re-decode while sweeping
    that latent bit over opt.int_num interpolation levels.

    NOTE(review): this file defines main() more than once; at import
    time only the last definition is bound.
    """
    #global n_words
    # Prepare training and testing data
    opt = COptions(args)
    opt_t = COptions(args)
    # opt_t.n_hid = opt.n_z
    loadpath = (opt.data_dir + "/" + opt.data_name) #if opt.not_philly else '/hdfs/msrlabs/xiag/pt-data/cons/data_cleaned/twitter_small.p'
    print "loadpath:" + loadpath
    # Pickle layout: [train, val, test, wordtoix, ixtoword]
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    if opt.test:
        test_file = opt.data_dir + opt.test_file
        test = read_test(test_file, wordtoix)
        # test = [ x for x in test if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # train_filtered = [ x for x in train if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # val_filtered = [ x for x in val if all([2<len(x[t])<opt.maxlen - 4 for t in range(opt.num_turn)])]
    # print ("Train: %d => %d" % (len(train), len(train_filtered)))
    # print ("Val: %d => %d" % (len(val), len(val_filtered)))
    # train, val = train_filtered, val_filtered
    # del train_filtered, val_filtered
    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    # Decoder max length shrinks by (filter_shape - 1), matching the conv encoder.
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)
    # print dict(opt)
    # if opt.model == 'cnn_rnn':
    #     opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    #     opt_t.update_params(args)
    #     print dict(opt_t)

    #for d in ['/gpu:0', '/gpu:1', '/gpu:2', '/gpu:3']:
    for d in ['/gpu:0']:
        with tf.device(d):
            # One int placeholder per context turn, plus target, latent z, and train flag.
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            # z width doubles when a local feature is concatenated with the global one.
            z_ = tf.placeholder(tf.float32, shape=[opt_t.batch_size , opt.n_z * (2 if opt.local_feature else 1)])
            is_train_ = tf.placeholder(tf.bool, name = 'is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
            res_2_ = generate_resp(src_, tgt_, z_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    #tensorboard --logdir=run1:/tmp/tensorflow/ --port 6006
    #writer = tf.train.SummaryWriter(opt.log_path, graph=tf.get_default_graph())
    uidx = 0
    graph_options=tf.GraphOptions(build_cost_model=1)
    #config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=tf.GraphOptions(build_cost_model=1))
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=graph_options)
    # config.gpu_options.per_process_gpu_memory_fraction = 0.70
    #config = tf.ConfigProto(device_count={'GPU':0})
    #config.gpu_options.allow_growth = True

    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                #pdb.set_trace()
                t_vars = tf.trainable_variables()
                #t_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) #tf.trainable_variables()
                # if opt.load_from_pretrain:
                #     d_vars = [var for var in t_vars if var.name.startswith('d_')]
                #     g_vars = [var for var in t_vars if var.name.startswith('g_')]
                #     l_vars = [var for var in t_vars if var.name.startswith('l_')]
                #     #restore_from_save(g_vars, sess, opt, prefix = 'g_', load_path=opt.restore_dir + "/save/generator2")
                #     restore_from_save(d_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.global_d)
                #     if opt.local_feature:
                #         restore_from_save(l_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.local_d)
                # else:
                loader = restore_from_save(t_vars, sess, opt, load_path = opt.save_path)
            except Exception as e:
                # Fall back to fresh weights when the checkpoint cannot be loaded.
                print 'Error: '+str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        loss_d , loss_g = 0, 0

        if opt.test:
            # --- Pass 1: decode each turn with the z inferred from the context. ---
            iter_num = np.int(np.floor(len(test)/opt.batch_size))+1
            res_all = []
            val_tgt_all =[]
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,(i+1) * opt.batch_size)
                # Index modulo len(test) so the final partial batch wraps around.
                sents = [test[t%len(test)] for t in test_index]
                for idx in range(opt.n_context,opt.num_turn):
                    # src: the n_context turns preceding turn idx, oldest first.
                    src = [[sents[i][idx-turn] for i in range(opt.batch_size)] for turn in range(opt.n_context,0,-1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    val_tgt_all.extend(tgt)

                    if opt.feed_generated and idx!= opt.n_context:
                        # Feed the model's own previous responses back in as the latest
                        # context turn (0 is the pad id and is stripped).
                        src[-1] = [[x for x in p if x!=0] for p in res_all[-opt.batch_size:]]

                    x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src] # Batch L
                    y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False)
                    feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0}) # do not use False
                    res_1 = sess.run(res_1_, feed_dict=feed)
                    z_all = np.array(res_1['z'])

                    feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, z_: z_all, is_train_: 0}) # do not use False
                    res_2 = sess.run(res_2_, feed_dict=feed)
                    res_all.extend(res_2['syn_sent'])

            # bp()
            val_tgt_all = reshaping(val_tgt_all, opt)
            res_all = reshaping(res_all, opt)

            save_path = opt.log_path + '.resp.txt'
            if os.path.exists(save_path):
                os.remove(save_path)
            for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                with open(save_path, "a") as resp_f:
                    # Token ids 0 and 2 are dropped (0 is pad; 2 presumably EOS — TODO confirm);
                    # turns of one dialogue are tab-joined, dialogues newline-separated.
                    resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
            print ("save to:" + save_path)

            if opt.verbose:
                # Also dump the ground-truth targets in the same layout.
                save_path = opt.log_path + '.tgt.txt'
                if os.path.exists(save_path):
                    os.remove(save_path)
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path, "a") as tgt_f:
                        tgt_f.write(u' '.join([ixtoword[x] for x in val_tgt_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
                print ("save to:" + save_path)

            val_set = [prepare_for_bleu(s) for s in val_tgt_all]
            gen = [prepare_for_bleu(s) for s in res_all]
            [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
            etp_score, dist_score = cal_entropy(gen)

            # print save_path
            print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)])
            # print 'Val Rouge: ' + ' '.join([str(round(it,3)) for it in (rouge1,rouge2,rouge3,rouge4)])
            print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])])
            print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])])
            # print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it,3)) for it in (rel_score[0],rel_score[1],rel_score[2])])
            print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3))

            if opt.embedding_score:
                with open("../../ssd0/consistent_dialog/data/GoogleNews-vectors-negative300.bin.p", 'rb') as pfile:
                    embedding = cPickle.load(pfile)
                rel_score = cal_relevance(gen, val_set, embedding)
                print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it,3)) for it in (rel_score[0],rel_score[1],rel_score[2])])

            # The interpolation sweep below needs a global feature and a chosen bit.
            if not opt.global_feature or opt.bit == None:
                exit(0)

        if opt.test:
            # --- Pass 2: re-decode while clamping latent dimension opt.bit to
            # int_num evenly spaced values in [0, 1]. ---
            iter_num = np.int(np.floor(len(test)/opt.batch_size))+1
            for int_idx in range(opt.int_num):
                res_all = []
                z1,z2,z3 = [],[],[]
                val_tgt_all =[]
                for i in range(iter_num):
                    test_index = range(i * opt.batch_size,(i+1) * opt.batch_size)
                    sents = [test[t%len(test)] for t in test_index]
                    for idx in range(opt.n_context,opt.num_turn):
                        src = [[sents[i][idx-turn] for i in range(opt.batch_size)] for turn in range(opt.n_context,0,-1)]
                        tgt = [sents[i][idx] for i in range(opt.batch_size)]
                        val_tgt_all.extend(tgt)

                        if opt.feed_generated and idx!= opt.n_context:
                            src[-1] = [[x for x in p if x!=0] for p in res_all[-opt.batch_size:]]

                        x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src] # Batch L
                        y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False)
                        feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0}) # do not use False
                        res_1 = sess.run(res_1_, feed_dict=feed)
                        z_all = np.array(res_1['z'])
                        # Overwrite the selected bit with this interpolation level for the whole batch.
                        z_all[:,opt.bit] = np.array([1.0/np.float(opt.int_num-1) * int_idx for _ in range(opt.batch_size)])

                        feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, z_: z_all, is_train_: 0}) # do not use False
                        res_2 = sess.run(res_2_, feed_dict=feed)
                        res_all.extend(res_2['syn_sent'])
                        z1.extend(res_1['z'])      # z inferred from context (pre-clamp)
                        z2.extend(z_all)           # z actually fed to the generator
                        z3.extend(res_2['z_hat'])  # z re-inferred from the generated response

                # bp()
                val_tgt_all = reshaping(val_tgt_all, opt)
                res_all = reshaping(res_all, opt)
                z1 = reshaping(z1, opt)
                z2 = reshaping(z2, opt)
                z3 = reshaping(z3, opt)

                save_path = opt.log_path + 'bit' + str(opt.bit) + '.'+ str(1.0/np.float(opt.int_num-1) * int_idx) +'.int.txt'
                if os.path.exists(save_path):
                    os.remove(save_path)
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path, "a") as resp_f:
                        resp_f.write(u' '.join([ixtoword[x] for x in res_all[idx] if x != 0 and x != 2]).encode('utf-8').strip() + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t') )
                print ("save to:" + save_path)

                # Dump the re-inferred value of the clamped bit per generated turn.
                save_path_z = opt.log_path + 'bit' + str(opt.bit) + '.'+ str(1.0/np.float(opt.int_num-1) * int_idx) +'.z.txt'
                if os.path.exists(save_path_z):
                    os.remove(save_path_z)
                for idx in range(len(test)*(opt.num_turn-opt.n_context)):
                    with open(save_path_z, "a") as myfile:
                        #ary = np.array([z1[idx][opt.bit], z2[idx][opt.bit], z3[idx][opt.bit]])
                        #myfile.write(np.array2string(ary, formatter={'float_kind':lambda x: "%.2f" % x}) + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t'))
                        myfile.write(str(z3[idx][opt.bit]) + ('\n' if idx%(opt.num_turn-opt.n_context) == opt.num_turn-opt.n_context-1 else '\t'))

                val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                gen = [prepare_for_bleu(s) for s in res_all]
                [bleu1s,bleu2s,bleu3s,bleu4s] = cal_BLEU_4(gen, {0: val_set}, is_corpus = opt.is_corpus)
                etp_score, dist_score = cal_entropy(gen)
                print save_path
                print 'Val BLEU: ' + ' '.join([str(round(it,3)) for it in (bleu1s,bleu2s,bleu3s,bleu4s)])
                # print 'Val Rouge: ' + ' '.join([str(round(it,3)) for it in (rouge1,rouge2,rouge3,rouge4)])
                print 'Val Entropy: ' + ' '.join([str(round(it,3)) for it in (etp_score[0],etp_score[1],etp_score[2],etp_score[3])])
                print 'Val Diversity: ' + ' '.join([str(round(it,3)) for it in (dist_score[0],dist_score[1],dist_score[2],dist_score[3])])
                # print 'Val Relevance(G,A,E): ' + ' '.join([str(round(it,3)) for it in (rel_score[0],rel_score[1],rel_score[2])])
                print 'Val Avg. length: ' + str(round(np.mean([len([y for y in x if y!=0]) for x in res_all]),3))
def main():
    """Training driver for the conditional seq2seq model: filter dialogues
    by turn length, build conditional_s2s, then alternate training steps
    with periodic console samples, validation BLEU/entropy/diversity, and
    checkpointing. With opt.test set, decode once and exit.

    NOTE(review): this file defines main() more than once; at import
    time only the last definition is bound.
    """
    #global n_words
    # Prepare training and testing data
    opt = COptions(args)
    opt_t = COptions(args)
    loadpath = (opt.data_dir + "/" + opt.data_name)
    print "loadpath:" + loadpath
    # Pickle layout: [train, val, test, wordtoix, ixtoword]
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    if opt.test:
        test_file = opt.data_dir + "/newdata2/test.txt"
        test = read_test(test_file, wordtoix)
        # Keep only dialogues whose every turn fits the model's length window.
        test = [
            x for x in test
            if all(
                [2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
        ]
    # NOTE(review): reconstructed from a collapsed source; the train/val
    # filtering is assumed to run unconditionally — confirm against the repo.
    train_filtered = [
        x for x in train
        if all([2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
    ]
    val_filtered = [
        x for x in val
        if all([2 < len(x[t]) < opt.maxlen - 4 for t in range(opt.num_turn)])
    ]
    print("Train: %d => %d" % (len(train), len(train_filtered)))
    print("Val: %d => %d" % (len(val), len(val_filtered)))
    train, val = train_filtered, val_filtered
    del train_filtered, val_filtered

    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    # Decoder max length shrinks by (filter_shape - 1), matching the conv encoder.
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            # One int placeholder per context turn, plus target and train flag.
            src_ = [
                tf.placeholder(tf.int32,
                               shape=[opt.batch_size, opt.sent_len])
                for _ in range(opt.n_context)
            ]
            tgt_ = tf.placeholder(tf.int32,
                                  shape=[opt_t.batch_size, opt_t.sent_len])
            is_train_ = tf.placeholder(tf.bool, name='is_train')
            res_, gan_cost_g_, train_op_g = conditional_s2s(
                src_, tgt_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options = tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True,
                            graph_options=graph_options)
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test',
                                            sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                if opt.load_from_pretrain:
                    # Restore discriminator / generator / local-feature
                    # variable groups from their pretraining checkpoints.
                    d_vars = [
                        var for var in t_vars if var.name.startswith('d_')
                    ]
                    g_vars = [
                        var for var in t_vars if var.name.startswith('g_')
                    ]
                    l_vars = [
                        var for var in t_vars if var.name.startswith('l_')
                    ]
                    restore_from_save(d_vars, sess, opt,
                                      load_path=opt.restore_dir + "/save/" +
                                      opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars, sess, opt,
                                          load_path=opt.restore_dir +
                                          "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars, sess, opt,
                                               load_path=opt.save_path)
            except Exception as e:
                # Fall back to fresh weights when the checkpoint cannot be loaded.
                print 'Error: ' + str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        loss_d, loss_g = 0, 0

        if opt.test:
            # Decode-only path: generate a response for every turn and exit.
            iter_num = np.int(np.floor(len(test) / opt.batch_size)) + 1
            res_all = []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,
                                   (i + 1) * opt.batch_size)
                # NOTE(review): indexes `val`, not `test`, and without the
                # modulo wrap used elsewhere — looks intentional here but
                # verify against the repo.
                sents = [val[t] for t in test_index]
                for idx in range(opt.n_context, opt.num_turn):
                    # src: the n_context turns preceding turn idx, oldest first.
                    src = [[
                        sents[i][idx - turn] for i in range(opt.batch_size)
                    ] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    x_batch = [
                        prepare_data_for_cnn(src_i, opt) for src_i in src
                    ]  # Batch L
                    y_batch = prepare_data_for_rnn(tgt,
                                                   opt_t,
                                                   is_add_GO=False)
                    feed = merge_two_dicts(
                        {i: d for i, d in zip(src_, x_batch)}, {
                            tgt_: y_batch,
                            is_train_: 0
                        })  # do not use False
                    res = sess.run(res_, feed_dict=feed)
                    res_all.extend(res['syn_sent'])

            # bp()
            res_all = reshaping(res_all, opt)

            for idx in range(len(test) * (opt.num_turn - opt.n_context)):
                with open(opt.log_path + '.resp.txt', "a") as resp_f:
                    # Token ids 0 and 2 are dropped (0 is pad; 2 presumably EOS
                    # — TODO confirm); turns are tab-joined per dialogue.
                    resp_f.write(u' '.join([
                        ixtoword[x]
                        for x in res_all[idx] if x != 0 and x != 2
                    ]).encode('utf-8').strip() + (
                        '\n' if idx % (opt.num_turn - opt.n_context) == 0
                        else '\t'))
            print("save to:" + opt.log_path + '.resp.txt')
            exit(0)

        for epoch in range(opt.max_epochs):
            print("Starting epoch %d" % epoch)
            kf = get_minibatches_idx(len(train),
                                     opt.batch_size,
                                     shuffle=True)
            for _, train_index in kf:
                uidx += 1
                sents = [train[t] for t in train_index]
                for idx in range(opt.n_context, opt.num_turn):
                    src = [[
                        sents[i][idx - turn] for i in range(opt.batch_size)
                    ] for turn in range(opt.n_context, 0, -1)]
                    tgt = [sents[i][idx] for i in range(opt.batch_size)]
                    x_batch = [
                        prepare_data_for_cnn(src_i, opt) for src_i in src
                    ]  # Batch L
                    y_batch = prepare_data_for_rnn(tgt,
                                                   opt_t,
                                                   is_add_GO=False)
                    feed = merge_two_dicts(
                        {i: d for i, d in zip(src_, x_batch)}, {
                            tgt_: y_batch,
                            is_train_: 1
                        })
                    # One generator update per (context, target) pair.
                    _, loss_g = sess.run([train_op_g, gan_cost_g_],
                                         feed_dict=feed)

                if uidx % opt.print_freq == 0:
                    # Print the loss and one decoded sample from the last feed.
                    print("Iteration %d: loss G %f" % (uidx, loss_g))
                    res = sess.run(res_, feed_dict=feed)
                    if opt.global_feature:
                        print "z loss: " + str(res['z_loss'])
                    if "nn" in opt.agg_model:
                        print "z pred_loss: " + str(res['z_loss_pred'])
                    print "Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch for x in s[0]
                         if x != 0]).encode('utf-8').strip()
                    print "Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""
                    sys.stdout.flush()
                    summary = sess.run(merged, feed_dict=feed)
                    train_writer.add_summary(summary, uidx)

                if uidx % opt.valid_freq == 1:
                    # Validate on ~VALID_SIZE randomly sampled dialogues.
                    VALID_SIZE = 4096
                    valid_multiplier = np.int(
                        np.floor(VALID_SIZE / opt.batch_size))
                    res_all, val_tgt_all, loss_val_g_all = [], [], []
                    if opt.global_feature:
                        z_loss_all = []
                    for val_step in range(valid_multiplier):
                        valid_index = np.random.choice(len(val),
                                                       opt.batch_size)
                        sents = [val[t] for t in valid_index]
                        for idx in range(opt.n_context, opt.num_turn):
                            src = [[
                                sents[i][idx - turn]
                                for i in range(opt.batch_size)
                            ] for turn in range(opt.n_context, 0, -1)]
                            tgt = [
                                sents[i][idx] for i in range(opt.batch_size)
                            ]
                            val_tgt_all.extend(tgt)
                            x_batch = [
                                prepare_data_for_cnn(src_i, opt)
                                for src_i in src
                            ]  # Batch L
                            y_batch = prepare_data_for_rnn(tgt,
                                                           opt_t,
                                                           is_add_GO=False)
                            feed = merge_two_dicts(
                                {i: d for i, d in zip(src_, x_batch)}, {
                                    tgt_: y_batch,
                                    is_train_: 0
                                })  # do not use False
                            loss_val_g = sess.run([gan_cost_g_],
                                                  feed_dict=feed)
                            loss_val_g_all.append(loss_val_g)
                            res = sess.run(res_, feed_dict=feed)
                            res_all.extend(res['syn_sent'])
                            if opt.global_feature:
                                z_loss_all.append(res['z_loss'])

                    print("Validation: loss G %f " %
                          (np.mean(loss_val_g_all)))
                    if opt.global_feature:
                        print "z loss: " + str(np.mean(z_loss_all))
                    print "Val Source:" + u' '.join(
                        [ixtoword[x] for s in x_batch for x in s[0]
                         if x != 0]).encode('utf-8').strip()
                    print "Val Target:" + u' '.join([
                        ixtoword[x] for x in y_batch[0] if x != 0
                    ]).encode('utf-8').strip()
                    print "Val Generated:" + u' '.join([
                        ixtoword[x] for x in res['syn_sent'][0] if x != 0
                    ]).encode('utf-8').strip()
                    print ""
                    if opt.global_feature:
                        # Append a z diagnostic record (input z, reconstructed
                        # z_hat, target z_tgt) for the last validation batch.
                        with open(opt.log_path + '.z.txt', "a") as myfile:
                            myfile.write("Iteration" + str(uidx) + "\n")
                            myfile.write("z_loss %f" %
                                         (np.mean(z_loss_all)) + "\n")
                            myfile.write("Val Source:" + u' '.join([
                                ixtoword[x] for s in x_batch for x in s[0]
                                if x != 0
                            ]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Target:" + u' '.join(
                                [ixtoword[x] for x in y_batch[0]
                                 if x != 0]).encode('utf-8').strip() + "\n")
                            myfile.write("Val Generated:" + u' '.join([
                                ixtoword[x] for x in res['syn_sent'][0]
                                if x != 0
                            ]).encode('utf-8').strip() + "\n")
                            myfile.write("Z_input, Z_recon, Z_tgt")
                            myfile.write(
                                np.array2string(res['z'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n")
                            myfile.write(
                                np.array2string(res['z_hat'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n\n")
                            myfile.write(
                                np.array2string(res['z_tgt'][0],
                                                formatter={
                                                    'float_kind':
                                                    lambda x: "%.2f" % x
                                                }) + "\n\n")

                    val_set = [prepare_for_bleu(s) for s in val_tgt_all]
                    gen = [prepare_for_bleu(s) for s in res_all]
                    [bleu1s, bleu2s, bleu3s,
                     bleu4s] = cal_BLEU_4(gen, {0: val_set},
                                          is_corpus=opt.is_corpus)
                    etp_score, dist_score = cal_entropy(gen)
                    print 'Val BLEU: ' + ' '.join([
                        str(round(it, 3))
                        for it in (bleu1s, bleu2s, bleu3s, bleu4s)
                    ])
                    print 'Val Entropy: ' + ' '.join([
                        str(round(it, 3))
                        for it in (etp_score[0], etp_score[1], etp_score[2],
                                   etp_score[3])
                    ])
                    print 'Val Diversity: ' + ' '.join([
                        str(round(it, 3))
                        for it in (dist_score[0], dist_score[1],
                                   dist_score[2], dist_score[3])
                    ])
                    print 'Val Avg. length: ' + str(
                        round(
                            np.mean([
                                len([y for y in x if y != 0])
                                for x in res_all
                            ]), 3))
                    print ""
                    # Scalar summaries for TensorBoard alongside the merged graph summary.
                    summary = sess.run(merged, feed_dict=feed)
                    summary2 = tf.Summary(value=[
                        tf.Summary.Value(tag="bleu-2", simple_value=bleu2s),
                        tf.Summary.Value(tag="etp-4",
                                         simple_value=etp_score[3])
                    ])
                    test_writer.add_summary(summary, uidx)
                    test_writer.add_summary(summary2, uidx)

                if uidx % opt.save_freq == 0:
                    saver.save(sess, opt.save_path)
def main():
    """Feature-extraction driver: run get_features over the test set's
    first turn and dump the global z and local z_l vectors, one
    tab-separated row per test example, to '<log_path>.global.z.txt'
    and '<log_path>.local.z.txt'.

    NOTE(review): this file defines main() more than once; at import
    time only the last definition is bound.
    """
    opt = COptions(args)
    opt_t = COptions(args)
    loadpath = (opt.data_dir + "/" + opt.data_name)
    print "loadpath:" + loadpath
    # Pickle layout: [train, val, test, wordtoix, ixtoword]
    x = cPickle.load(open(loadpath, "rb"))
    train, val, test = x[0], x[1], x[2]
    wordtoix, ixtoword = x[3], x[4]
    if opt.test:
        test_file = opt.data_dir + opt.test_file
        test = read_test(test_file, wordtoix)
    opt.n_words = len(ixtoword)
    opt_t.n_words = len(ixtoword)
    # Decoder max length shrinks by (filter_shape - 1), matching the conv encoder.
    opt_t.maxlen = opt_t.maxlen - opt_t.filter_shape + 1
    opt_t.update_params(args)
    print datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    print dict(opt)
    print('Total words: %d' % opt.n_words)

    for d in ['/gpu:0']:
        with tf.device(d):
            # Only the feature-extraction graph is built here; no generator.
            src_ = [tf.placeholder(tf.int32, shape=[opt.batch_size, opt.sent_len]) for _ in range(opt.n_context)]
            tgt_ = tf.placeholder(tf.int32, shape=[opt_t.batch_size, opt_t.sent_len])
            is_train_ = tf.placeholder(tf.bool, name = 'is_train')
            res_1_ = get_features(src_, tgt_, is_train_, opt, opt_t)
            merged = tf.summary.merge_all()

    uidx = 0
    graph_options=tf.GraphOptions(build_cost_model=1)
    config = tf.ConfigProto(log_device_placement = False, allow_soft_placement=True, graph_options=graph_options)
    np.set_printoptions(precision=3)
    np.set_printoptions(threshold=np.inf)
    saver = tf.train.Saver()
    run_metadata = tf.RunMetadata()

    with tf.Session(config = config) as sess:
        train_writer = tf.summary.FileWriter(opt.log_path + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(opt.log_path + '/test', sess.graph)
        sess.run(tf.global_variables_initializer())
        if opt.restore:
            try:
                t_vars = tf.trainable_variables()
                if opt.load_from_pretrain:
                    # Restore discriminator / local-feature groups from pretraining checkpoints.
                    d_vars = [var for var in t_vars if var.name.startswith('d_')]
                    l_vars = [var for var in t_vars if var.name.startswith('l_')]
                    restore_from_save(d_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.global_d)
                    if opt.local_feature:
                        restore_from_save(l_vars, sess, opt, load_path = opt.restore_dir + "/save/" + opt.local_d)
                else:
                    loader = restore_from_save(t_vars, sess, opt, load_path = opt.save_path)
            except Exception as e:
                # Fall back to fresh weights when the checkpoint cannot be loaded.
                print 'Error: '+str(e)
                print("No saving session, using random initialization")
                sess.run(tf.global_variables_initializer())

        loss_d , loss_g = 0, 0

        if opt.test:
            iter_num = np.int(np.floor(len(test)/opt.batch_size))+1
            z_all, z_all_l = [], []
            for i in range(iter_num):
                test_index = range(i * opt.batch_size,(i+1) * opt.batch_size)
                # Index modulo len(test) so the final partial batch wraps around.
                sents = [test[t%len(test)] for t in test_index]
                # Features are extracted from each example's first turn only.
                src = [[sents[i][0] for i in range(opt.batch_size)]]
                tgt = [sents[i][0] for i in range(opt.batch_size)]
                x_batch = [prepare_data_for_cnn(src_i, opt) for src_i in src]
                print "Source:" + u' '.join([ixtoword[x] for s in x_batch for x in s[0] if x != 0]).encode('utf-8').strip()
                y_batch = prepare_data_for_rnn(tgt, opt_t, is_add_GO = False)
                feed = merge_two_dicts( {i: d for i, d in zip(src_, x_batch)}, {tgt_: y_batch, is_train_: 0})
                res_1 = sess.run(res_1_, feed_dict=feed)
                z_all.extend(res_1['z'])
                z_all_l.extend(res_1['z_l'])

            # Write global z: one tab-separated row per test example
            # (truncated to len(test) to drop the wrap-around duplicates).
            save_path_z = opt.log_path + '.global.z.txt'
            print save_path_z
            if os.path.exists(save_path_z):
                os.remove(save_path_z)
            with open(save_path_z, "a") as myfile:
                for line in z_all[:len(test)]:
                    for z_it in line:
                        myfile.write(str(z_it) + '\t')
                    myfile.write('\n')

            # Write local z_l in the same layout.
            save_path_z = opt.log_path + '.local.z.txt'
            print save_path_z
            if os.path.exists(save_path_z):
                os.remove(save_path_z)
            with open(save_path_z, "a") as myfile:
                for line in z_all_l[:len(test)]:
                    for z_it in line:
                        myfile.write(str(z_it) + '\t')
                    myfile.write('\n')