def _main(_):
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    val_data = tx.data.MultiAlignedData(config.val_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    if config.manual:
        manual_data = tx.data.MultiAlignedData(config.manual_data)
    vocab = train_data.vocab(0)

    # Each training batch is used three times: once for updating the
    # discriminator, once for the generator, and once for the z-space
    # losses. A feedable data iterator supports this reuse.
    if config.manual:
        iterator = tx.data.FeedableDataIterator({
            'train_g': train_data,
            'train_d': train_data,
            'train_z': train_data,
            'val': val_data,
            'test': test_data,
            'manual': manual_data
        })
    else:
        iterator = tx.data.FeedableDataIterator({
            'train_g': train_data,
            'train_d': train_data,
            'train_z': train_data,
            'val': val_data,
            'test': test_data
        })
    batch = iterator.get_next()

    # Model: the loss weights are fed as placeholders so they can be
    # annealed across epochs.
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    lambda_z = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_z')
    lambda_z1 = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_z1')
    lambda_z2 = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_z2')
    lambda_ae = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_ae')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, lambda_z, lambda_z1,
                         lambda_z2, lambda_ae, config.model)

    def _train_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                     lambda_z2_, lambda_ae_, epoch, verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        avg_meters_z = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                # Discriminator update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_d'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_
                }
                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                avg_meters_d.add(vals_d)

                # Generator update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_g'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_
                }
                vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
                avg_meters_g.add(vals_g)

                # z-space update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_z'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_
                }
                vals_z = sess.run(model.fetches_train_z, feed_dict=feed_dict)
                avg_meters_z.add(vals_z)

                if verbose and (step == 1 or step % config.display == 0):
                    print('epoch: {}, step: {}, {}'.format(
                        epoch, step, avg_meters_d.to_str(4)))
                    print('epoch: {}, step: {}, {}'.format(
                        epoch, step, avg_meters_z.to_str(4)))
                    print('epoch: {}, step: {}, {}'.format(
                        epoch, step, avg_meters_g.to_str(4)))

                if verbose and step % config.display_eval == 0:
                    iterator.restart_dataset(sess, 'val')
                    _eval_epoch(sess, gamma_, lambda_g_, lambda_z_,
                                lambda_z1_, lambda_z2_, lambda_ae_, epoch)
            except tf.errors.OutOfRangeError:
                print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_z.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
                break

    def _eval_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                    lambda_z2_, lambda_ae_, epoch, val_or_test='val',
                    plot_z=False, plot_max_count=1000, spam=False,
                    repetitions=False, write_text=True, write_labels=False):
        avg_meters = tx.utils.AverageRecorder()
        if plot_z:
            z_vectors = []
            labels = []
            tsne = TSNE(n_components=2)

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)
                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)
                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                if spam or repetitions:
                    target_labels = samples['labels_target']
                    predicted_labels = samples['labels_predicted']
                    results = [(r, h, t, p) for r, h, t, p in zip(
                        refs, hyps, target_labels, predicted_labels)]

                    # Counts transfers that merely copy the input.
                    if repetitions:
                        count_equal_strings = 0
                        remain_samples_e = []
                        for r, h, t, p in results:
                            if r == h:
                                count_equal_strings += 1
                            else:
                                remain_samples_e.append((r, h, t, p))
                        vals['equal'] = count_equal_strings / len(hyps)

                    # Counts degenerate outputs ending in repeated words.
                    if spam:
                        count_spam = 0
                        remain_samples_s = []
                        for r, h, t, p in results:
                            words = h.split()
                            if len(words) > 2 and words[-1] == words[-2]:
                                count_spam += 1
                            elif len(words) > 4 and words[-1] == words[-3] \
                                    and words[-2] == words[-4]:
                                count_spam += 1
                            else:
                                remain_samples_s.append((r, h, t, p))
                        vals['spam'] = count_spam / len(hyps)

                    if repetitions and spam:
                        remain_samples = [
                            sample for sample in remain_samples_e
                            if sample in remain_samples_s
                        ]
                    elif not repetitions and spam:
                        remain_samples = remain_samples_s
                    elif repetitions and not spam:
                        remain_samples = remain_samples_e

                    if repetitions and spam:
                        refs_remain = [r for r, h, t, p in remain_samples]
                        hyps_remain = [h for r, h, t, p in remain_samples]
                        bleu_remain = tx.evals.corpus_bleu_moses(
                            refs_remain, hyps_remain)
                        vals['bleu_remain'] = bleu_remain

                    if len(remain_samples) != 0:
                        true_labels = 0
                        for _, _, t, p in remain_samples:
                            if t == p:
                                true_labels += 1
                        vals['acc_remain'] = true_labels / len(remain_samples)
                    else:
                        vals['acc_remain'] = 0.

                avg_meters.add(vals, weight=batch_size)

                if plot_z:
                    z_vectors += samples['z_vector'].tolist()
                    labels += samples['labels_source'].tolist()

                # Writes samples
                if write_text:
                    tx.utils.write_paired_text(
                        refs.squeeze(), hyps,
                        os.path.join(config.sample_path,
                                     'text_{}.{}'.format(val_or_test, epoch)),
                        append=True, mode='v')

                # Writes target/predicted label pairs
                if write_labels:
                    tx.utils.write_paired_text(
                        [str(l) for l in samples['labels_target'].tolist()],
                        [str(l) for l in samples['labels_predicted'].tolist()],
                        os.path.join(config.sample_path,
                                     'labels_{}.{}'.format(val_or_test,
                                                           epoch)),
                        append=True, mode='v')
            except tf.errors.OutOfRangeError:
                print('epoch: {}, {}: {}'.format(
                    epoch, val_or_test, avg_meters.to_str(precision=4)))
                break

        if plot_z:
            # Projects the collected z vectors to 2D with t-SNE and saves a
            # scatter plot colored by the source label.
            if plot_max_count != 0:
                z_vectors = z_vectors[:plot_max_count]
                labels = labels[:plot_max_count]
            tsne_result = tsne.fit_transform(np.array(z_vectors))
            x_data = tsne_result[:, 0]
            y_data = tsne_result[:, 1]
            plt.scatter(x_data, y_data, c=np.array(labels), s=1,
                        cmap=plt.cm.get_cmap('jet', 2))
            plt.clim(0.0, 1.0)
            if not os.path.exists('./images'):
                os.makedirs('./images')
            plt.savefig('./images/{}_{}.png'.format(val_or_test, epoch))
            plt.clf()

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.
        lambda_g_ = 0.
        lambda_z_ = 0.
        lambda_ae_ = 1.
        lambda_z1_ = config.lambda_z1
        lambda_z2_ = config.lambda_z2
        for epoch in range(1, config.max_nepochs + 1):
            if epoch > config.pretrain_ae_nepochs:
                # Anneals the Gumbel-softmax temperature.
                gamma_ = max(0.001, gamma_ * config.gamma_decay)
                lambda_g_ = config.lambda_g
                lambda_z_ = config.lambda_z
            if epoch > config.chage_lambda_ae_epoch:
                lambda_ae_ = lambda_ae_ - config.change_lambda_ae
            print('gamma: {}, lambda_g: {}, lambda_z: {}, lambda_z1: {}, '
                  'lambda_z2: {}, lambda_ae: {}'.format(
                      gamma_, lambda_g_, lambda_z_, lambda_z1_, lambda_z2_,
                      lambda_ae_))

            # Train
            iterator.restart_dataset(sess, ['train_g', 'train_d', 'train_z'])
            _train_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                         lambda_z2_, lambda_ae_, epoch)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                        lambda_z2_, lambda_ae_, epoch, 'val',
                        plot_z=config.plot_z,
                        plot_max_count=config.plot_max_count,
                        spam=config.spam,
                        repetitions=config.repetitions,
                        write_text=config.write_text,
                        write_labels=config.write_labels)

            saver.save(sess, os.path.join(config.checkpoint_path, 'ckpt'),
                       epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                        lambda_z2_, lambda_ae_, epoch, 'test',
                        plot_z=config.plot_z,
                        plot_max_count=config.plot_max_count,
                        spam=config.spam,
                        repetitions=config.repetitions,
                        write_text=config.write_text,
                        write_labels=config.write_labels)

            # Manual set (optional)
            if config.manual:
                iterator.restart_dataset(sess, 'manual')
                _eval_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                            lambda_z2_, lambda_ae_, epoch, 'manual',
                            plot_z=config.plot_z,
                            plot_max_count=config.plot_max_count,
                            spam=config.spam,
                            repetitions=config.repetitions,
                            write_text=config.write_text,
                            write_labels=config.write_labels)
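# Minimal module scaffold assumed by the TF-1.x scripts in this section (the
# excerpt shows only the `_main` bodies). The import paths, the `config`
# flag, and the entry point follow the usual Texar example layout; the names
# below are assumptions, not part of the original files.
import importlib
import os

import numpy as np
import tensorflow as tf
import texar.tf as tx  # older Texar releases use `import texar as tx`
import matplotlib
matplotlib.use('Agg')  # assumed headless backend for the t-SNE figures
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from ctrl_gen_model import CtrlGenModel  # assumed model module name

flags = tf.flags
flags.DEFINE_string('config', 'config', 'The config module to use.')
FLAGS = flags.FLAGS
config = importlib.import_module(FLAGS.config)

if __name__ == '__main__':
    tf.app.run(main=_main)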
def _main(_):
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    vocab = train_data.vocab(0)

    # Evaluation-only script: only the test set is needed; the feedable
    # iterator keeps the interface consistent with the training scripts.
    iterator = tx.data.FeedableDataIterator({'test': test_data})
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, config.model)

    def _eval_epoch(sess, gamma_, lambda_g_, val_or_test='test'):
        avg_meters = tx.utils.AverageRecorder()

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)
                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)
                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                avg_meters.add(vals, weight=batch_size)

                # Writes samples
                tx.utils.write_paired_text(
                    refs.squeeze(), hyps,
                    os.path.join(config.sample_path, 'result'),
                    append=True, mode='v')
            except tf.errors.OutOfRangeError:
                print('{}: {}'.format(
                    val_or_test, avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        # Anneals the Gumbel-softmax temperature (a single decay step from
        # the initial value of 1.0).
        gamma_ = max(0.001, 1. * config.gamma_decay)
        lambda_g_ = config.lambda_g
        # print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

        # Test
        iterator.restart_dataset(sess, 'test')
        _eval_epoch(sess, gamma_, lambda_g_, 'test')
def _main(_):
    # Data
    train_autoencoder = tx.data.MultiAlignedData(config.train_autoencoder)
    dev_autoencoder = tx.data.MultiAlignedData(config.dev_autoencoder)
    test_autoencoder = tx.data.MultiAlignedData(config.test_autoencoder)
    train_discriminator = tx.data.MultiAlignedData(config.train_discriminator)
    dev_discriminator = tx.data.MultiAlignedData(config.dev_discriminator)
    test_discriminator = tx.data.MultiAlignedData(config.test_discriminator)
    train_defender = tx.data.MultiAlignedData(config.train_defender)
    test_defender = tx.data.MultiAlignedData(config.test_defender)
    vocab = train_discriminator.vocab(0)

    iterator = tx.data.FeedableDataIterator({
        'train_autoencoder': train_autoencoder,
        'dev_autoencoder': dev_autoencoder,
        'test_autoencoder': test_autoencoder,
        'train_discriminator': train_discriminator,
        'dev_discriminator': dev_discriminator,
        'test_discriminator': test_discriminator,
        'train_defender': train_defender,
        'test_defender': test_defender,
    })
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_D = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_D')
    lambda_ae_ = 1.0
    model = CtrlGenModel(batch, vocab, lambda_ae_, gamma, lambda_D,
                         config.model)

    def autoencoder(sess, lambda_ae_, gamma_, lambda_D_, epoch, mode,
                    verbose=True):
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        step = 0

        if mode == "train":
            dataset = "train_autoencoder"
            while True:
                try:
                    step += 1
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                    }

                    vals_g = sess.run(model.fetches_train_g,
                                      feed_dict=feed_dict)
                    loss_g_ae_summary = vals_g.pop("loss_g_ae_summary")
                    loss_g_clas_summary = vals_g.pop("loss_g_clas_summary")
                    avg_meters_g.add(vals_g)

                    if verbose and (step == 1 or step % config.display == 0):
                        print('step: {}, {}'.format(
                            step, avg_meters_g.to_str(4)))

                    if verbose and step % config.display_eval == 0:
                        iterator.restart_dataset(sess, 'dev_autoencoder')
                        # Dev-set evaluation; `_eval_epoch` is expected to
                        # be defined elsewhere in the project.
                        _eval_epoch(sess, lambda_ae_, gamma_, lambda_D_,
                                    epoch)
                except tf.errors.OutOfRangeError:
                    print('epoch: {}, {}'.format(
                        epoch, avg_meters_g.to_str(4)))
                    break
        else:
            dataset = "test_autoencoder"
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                        tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                    }

                    vals = sess.run(model.fetches_eval, feed_dict=feed_dict)
                    samples = tx.utils.dict_pop(vals,
                                                list(model.samples.keys()))
                    hyps = tx.utils.map_ids_to_strs(
                        samples['transferred'], vocab)
                    refs = tx.utils.map_ids_to_strs(
                        samples['original'], vocab)
                    refs = np.expand_dims(refs, axis=1)
                    avg_meters_g.add(vals)

                    # Writes samples
                    tx.utils.write_paired_text(
                        refs.squeeze(), hyps,
                        os.path.join(config.sample_path, 'val.%d' % epoch),
                        append=True, mode='v')
                except tf.errors.OutOfRangeError:
                    print('{}: {}'.format(
                        "test_autoencoder_only",
                        avg_meters_g.to_str(precision=4)))
                    break

    def discriminator(sess, lambda_ae_, gamma_, lambda_D_, epoch, mode,
                      verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        y_true = []
        y_pred = []
        y_prob = []
        sentences = []
        step = 0

        if mode == "train":
            dataset = "train_discriminator"
            while True:
                try:
                    step += 1
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                    }
                    vals_d = sess.run(model.fetches_train_d,
                                      feed_dict=feed_dict)
                    y_pred.extend(vals_d.pop("y_pred").tolist())
                    y_true.extend(vals_d.pop("y_true").tolist())
                    y_prob.extend(vals_d.pop("y_prob").tolist())
                    sentences.extend(vals_d.pop("sentences").tolist())
                    avg_meters_d.add(vals_d)

                    # if verbose and (step == 1
                    #                 or step % config.display == 0):
                    if verbose and step % 40 == 0:
                        print('step: {}, {}'.format(
                            step, avg_meters_d.to_str(4)))
                except tf.errors.OutOfRangeError:
                    # End of epoch: score on the dev set and report accuracy.
                    iterator.restart_dataset(sess, 'dev_discriminator')
                    _, _, _, _, val_acc = _eval_discriminator(
                        sess, lambda_ae_, gamma_, lambda_D_, epoch,
                        'dev_discriminator')
                    return val_acc

        if mode == 'test':
            dataset = "test_discriminator"
            iterator.restart_dataset(sess, dataset)
            y_pred, y_true, y_prob, sentences, _ = _eval_discriminator(
                sess, lambda_ae_, gamma_, lambda_D_, epoch, dataset)
            assert (len(y_pred) == len(y_true) == len(y_prob)
                    == len(sentences))

            # tp = 0
            # tn = 0
            # fp = 0
            # acc = 0
            # for sent, label, pred, prob in zip(sentences, y_true, y_pred,
            #                                    y_prob):
            #     if pred == 1 and label == 1:
            #         tp += 1.0 / len(y_true)
            #     if pred == 0 and label == 0:
            #         tn += 1.0 / len(y_true)
            #     if pred == 1 and label == 0:
            #         fp += 1.0 / len(y_true)
            #     if pred == label:
            #         acc += 1.0 / len(y_true)
            # print('true_positives:{}'.format(tp))
            # print('true_negatives:{}'.format(tn))
            # print('false_positives:{}'.format(fp))
            # print('accuracy:{}'.format(acc))

            # with open('prob_vocab.txt', 'w') as file:
            #     for word, prob_values in zip(sentences, y_prob):
            #         file.write(word)
            #         file.write('\t')
            #         file.write(str(prob_values))
            #         file.write('\n')

            # txt = open('rand_sent_from_vocab_Discriminator_label.txt', 'w')
            # with open('rand_sent_from_vocab_Discriminator.txt', 'w') as file:
            #     for sentence, pred in zip(sentences, y_pred):
            #         file.write(sentence + '\n')
            #         txt.write(str(pred) + '\n')

            # Writes sentences the discriminator confirms as negative, with
            # the predicted labels in a separate file (the label file name
            # is an assumption).
            with open(DATA_DIR +
                      'rand_x_sent_from_vocab_Discriminator_neg_confirmed'
                      '.txt', 'w') as file, \
                    open(DATA_DIR +
                         'rand_x_sent_from_vocab_Discriminator_neg_confirmed'
                         '_label.txt', 'w') as txt:
                for sentence, pred, label in zip(sentences, y_pred, y_true):
                    if pred == 0 and label == 0:
                        file.write(sentence + '\n')
                        txt.write(str(pred) + '\n')

    def defender(sess, lambda_ae_, gamma_, lambda_D_, epoch, mode,
                 verbose=True):
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        step = 0

        if mode == "train":
            dataset = "train_defender"
            while True:
                try:
                    step += 1
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                    }

                    vals_g = sess.run(model.fetches_train_g,
                                      feed_dict=feed_dict)
                    loss_g_ae_summary = vals_g.pop("loss_g_ae_summary")
                    loss_g_clas_summary = vals_g.pop("loss_g_clas_summary")
                    avg_meters_g.add(vals_g)

                    if verbose and (step == 1 or step % config.display == 0):
                        print('step: {}, {}'.format(
                            step, avg_meters_g.to_str(4)))
                except tf.errors.OutOfRangeError:
                    print('epoch: {}, {}'.format(
                        epoch, avg_meters_g.to_str(4)))
                    break
        else:
            dataset = "test_defender"
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                        tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                    }

                    vals = sess.run(model.fetches_eval, feed_dict=feed_dict)
                    samples = tx.utils.dict_pop(vals,
                                                list(model.samples.keys()))
                    hyps = tx.utils.map_ids_to_strs(
                        samples['transferred'], vocab)
                    refs = tx.utils.map_ids_to_strs(
                        samples['original'], vocab)
                    refs = np.expand_dims(refs, axis=1)
                    avg_meters_g.add(vals)

                    # Writes samples
                    tx.utils.write_paired_text(
                        refs.squeeze(), hyps,
                        os.path.join(config.sample_path,
                                     'defender_val.%d' % epoch),
                        append=True, mode='v')
                except tf.errors.OutOfRangeError:
                    print('{}: {}'.format(
                        "test_defender", avg_meters_g.to_str(precision=4)))
                    break

    def _eval_discriminator(sess, lambda_ae_, gamma_, lambda_D_, epoch,
                            dataset):
        avg_meters_d = tx.utils.AverageRecorder()
        y_true = []
        y_pred = []
        y_prob = []
        sentences = []

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, dataset),
                    gamma: gamma_,
                    lambda_D: lambda_D_,
                }
                vals_d = sess.run(model.fetches_dev_test_d,
                                  feed_dict=feed_dict)
                y_pred.extend(vals_d.pop("y_pred").tolist())
                y_true.extend(vals_d.pop("y_true").tolist())
                y_prob.extend(vals_d.pop("y_prob").tolist())
                sentence = vals_d.pop("sentences").tolist()
                sentences.extend(tx.utils.map_ids_to_strs(sentence, vocab))
                batch_size = vals_d.pop('batch_size')
                avg_meters_d.add(vals_d, weight=batch_size)
            except tf.errors.OutOfRangeError:
                acc = avg_meters_d.avg()['accu_d']
                print('{}: {}'.format(
                    dataset, avg_meters_d.to_str(precision=4)))
                break

        return y_pred, y_true, y_prob, sentences, acc

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        print(config.restore)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.0
        lambda_D_ = 0.0
        prev_acc = 0

        # Train discriminator, keeping the checkpoint with the best dev
        # accuracy.
        for epoch in range(1, config.discriminator_nepochs + 1):
            print("Epoch number:", epoch)
            iterator.restart_dataset(sess, ['train_discriminator'])
            val_acc = discriminator(sess, lambda_ae_, gamma_, lambda_D_,
                                    epoch, mode='train')
            if val_acc > prev_acc:
                print("Accuracy is better, saving model")
                prev_acc = val_acc
                saver.save(
                    sess,
                    os.path.join(config.checkpoint_path,
                                 'discriminator_only_ckpt'), epoch)
            else:
                print("Accuracy is worse")

        # Test discriminator.
        iterator.restart_dataset(sess, ['test_discriminator'])
        print('gamma:{}'.format(gamma_))
        discriminator(sess, lambda_ae_, gamma_, lambda_D_, 1, mode='test')
        # Stops here; the remaining phases are run separately.
        exit()

        # Train autoencoder
        for epoch in range(1, config.autoencoder_nepochs + 1):
            iterator.restart_dataset(sess, ['train_autoencoder'])
            autoencoder(sess, lambda_ae_, gamma_, lambda_D_, epoch,
                        mode='train')
            saver.save(
                sess,
                os.path.join(config.checkpoint_path,
                             'discriminator_only_and_autoencoder_only_ckpt'),
                epoch)

        # Test autoencoder
        iterator.restart_dataset(sess, ['test_autoencoder'])
        autoencoder(sess, lambda_ae_, gamma_, lambda_D_, 1, mode='test')

        gamma_ = 1.0
        lambda_D_ = 1.0
        # gamma_decay = 0.5  # Gumbel-softmax temperature anneal rate

        # Train defender
        for epoch in range(0, config.full_nepochs):
            # gamma_ = max(0.001, gamma_ * 0.5)
            print('gamma: {}, lambda_ae: {}, lambda_D: {}'.format(
                gamma_, lambda_ae_, lambda_D_))
            iterator.restart_dataset(sess, ['train_defender'])
            defender(sess, lambda_ae_, gamma_, lambda_D_, epoch,
                     mode='train')
            saver.save(sess,
                       os.path.join(config.checkpoint_path, 'full_ckpt'),
                       epoch)

        # Test defender
        iterator.restart_dataset(sess, 'test_defender')
        defender(sess, lambda_ae_, gamma_, lambda_D_, 1, mode='test')
def main():
    # Data
    train_data = tx.data.MultiAlignedData(hparams=config.train_data,
                                          device=device)
    val_data = tx.data.MultiAlignedData(hparams=config.val_data,
                                        device=device)
    test_data = tx.data.MultiAlignedData(hparams=config.test_data,
                                         device=device)
    vocab = train_data.vocab(0)

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator.
    iterator = tx.data.DataIterator({
        'train': train_data,
        'val': val_data,
        'test': test_data
    })

    # Initial values of the annealed hyperparameters.
    gamma_ = 1.
    lambda_g_ = 0.

    # Model
    model = CtrlGenModel(vocab, hparams=config.model)
    model.to(device)

    # Creates optimizers: one for the discriminator, one for the full
    # generator objective, and one for the autoencoding pretraining loss.
    train_op_d = tx.core.get_optimizer(params=model.d_vars,
                                       hparams=config.model['opt'])
    train_op_g = tx.core.get_optimizer(params=model.g_vars,
                                       hparams=config.model['opt'])
    train_op_g_ae = tx.core.get_optimizer(params=model.g_vars,
                                          hparams=config.model['opt'])

    def _train_epoch(gamma_, lambda_g_, epoch, verbose=True):
        model.train()
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)

        iterator.switch_to_dataset("train")
        step = 0
        for batch in iterator:
            train_op_d.zero_grad()
            train_op_g_ae.zero_grad()
            train_op_g.zero_grad()
            step += 1

            # Discriminator update.
            vals_d = model(batch, gamma_, lambda_g_, mode="train",
                           component="D")
            loss_d = vals_d['loss_d']
            loss_d.backward()
            train_op_d.step()
            recorder_d = {
                key: value.detach().cpu().data
                for (key, value) in vals_d.items()
            }
            avg_meters_d.add(recorder_d)

            # Generator update: autoencoding loss only during pretraining,
            # the full objective afterwards.
            vals_g = model(batch, gamma_, lambda_g_, mode="train",
                           component="G")
            if epoch <= config.pretrain_nepochs:
                loss_g_ae = vals_g['loss_g_ae']
                loss_g_ae.backward()
                train_op_g_ae.step()
            else:
                loss_g = vals_g['loss_g']
                loss_g.backward()
                train_op_g.step()
            recorder_g = {
                key: value.detach().cpu().data
                for (key, value) in vals_g.items()
            }
            avg_meters_g.add(recorder_g)

            if verbose and (step == 1 or step % config.display == 0):
                print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

            if verbose and step % config.display_eval == 0:
                _eval_epoch(gamma_, lambda_g_, epoch)

        print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
        print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))

    @torch.no_grad()
    def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'):
        model.eval()
        avg_meters = tx.utils.AverageRecorder()

        iterator.switch_to_dataset(val_or_test)
        for batch in iterator:
            vals, samples = model(batch, gamma_, lambda_g_, mode='eval')
            batch_size = vals.pop('batch_size')

            # Computes BLEU
            hyps = tx.data.map_ids_to_strs(samples['transferred'].cpu(),
                                           vocab)
            refs = tx.data.map_ids_to_strs(samples['original'].cpu(), vocab)
            refs = np.expand_dims(refs, axis=1)
            bleu = tx.evals.corpus_bleu_moses(refs, hyps)
            vals['bleu'] = bleu

            avg_meters.add(vals, weight=batch_size)

            # Writes samples
            tx.utils.write_paired_text(
                refs.squeeze(), hyps,
                os.path.join(config.sample_path, 'val.%d' % epoch),
                append=True, mode='v')

        print('{}: {}'.format(val_or_test, avg_meters.to_str(precision=4)))
        return avg_meters.avg()

    os.makedirs(config.sample_path, exist_ok=True)
    os.makedirs(config.checkpoint_path, exist_ok=True)

    # Runs the logic
    if config.restore:
        print('Restore from: {}'.format(config.restore))
        ckpt = torch.load(config.restore)
        model.load_state_dict(ckpt['model'])
        train_op_d.load_state_dict(ckpt['optimizer_d'])
        train_op_g.load_state_dict(ckpt['optimizer_g'])

    for epoch in range(1, config.max_nepochs + 1):
        if epoch > config.pretrain_nepochs:
            # Anneals the Gumbel-softmax temperature.
            gamma_ = max(0.001, gamma_ * config.gamma_decay)
            lambda_g_ = config.lambda_g
        print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

        # Train
        _train_epoch(gamma_, lambda_g_, epoch)

        # Val
        _eval_epoch(gamma_, lambda_g_, epoch, 'val')

        states = {
            'model': model.state_dict(),
            'optimizer_d': train_op_d.state_dict(),
            'optimizer_g': train_op_g.state_dict()
        }
        torch.save(states, os.path.join(config.checkpoint_path, 'ckpt'))

        # Test
        _eval_epoch(gamma_, lambda_g_, epoch, 'test')
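# Minimal module scaffold assumed by the PyTorch variant above: `device`,
# `config`, and the imports are referenced but not defined in the excerpt.
# The config module name is an assumption.
import importlib
import os

import numpy as np
import torch
import texar.torch as tx

from ctrl_gen_model import CtrlGenModel  # assumed model module name

config = importlib.import_module('config')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    main()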
def _main(_):
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    val_data = tx.data.MultiAlignedData(config.val_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    vocab = train_data.vocab(0)

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. A feedable data iterator supports
    # this reuse.
    iterator = tx.data.FeedableDataIterator({
        'train_g': train_data,
        'train_d': train_data,
        'val': val_data,
        'test': test_data
    })
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, config.model)

    # BLEU summary, built once so evaluation does not keep adding ops to
    # the graph on every batch.
    bleu_ph = tf.placeholder(dtype=tf.float32, shape=[], name='bleu')
    bleu_summary = tf.summary.scalar('Bleu', bleu_ph)

    def _train_epoch(sess, gamma_, lambda_g_, epoch, verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                # Discriminator update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_d'),
                    gamma: gamma_,
                    lambda_g: lambda_g_
                }
                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                avg_meters_d.add(vals_d)

                # Generator update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_g'),
                    gamma: gamma_,
                    lambda_g: lambda_g_
                }
                vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
                avg_meters_g.add(vals_g)

                if verbose and (step == 1 or step % config.display == 0):
                    print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                    print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

                if verbose and step % config.display_eval == 0:
                    iterator.restart_dataset(sess, 'val')
                    _eval_epoch(sess, gamma_, lambda_g_, epoch)
            except tf.errors.OutOfRangeError:
                print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
                break

    def _eval_epoch(sess, gamma_, lambda_g_, epoch, val_or_test='val'):
        avg_meters = tx.utils.AverageRecorder()

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)
                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)
                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                avg_meters.add(vals, weight=batch_size)

                # Writes samples
                tx.utils.write_paired_text(
                    refs.squeeze(), hyps,
                    os.path.join(config.sample_path, 'val.%d' % epoch),
                    append=True, mode='v')

                # Logs BLEU to TensorBoard.
                result = sess.run(bleu_summary, feed_dict={bleu_ph: bleu})
                summary_writer.add_summary(result)
            except tf.errors.OutOfRangeError:
                print('{}: {}'.format(
                    val_or_test, avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # Single summary writer shared by all evaluation runs.
        summary_writer = tf.summary.FileWriter('./summary', sess.graph)

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.
        lambda_g_ = 0.
        for epoch in range(1, config.max_nepochs + 1):
            if epoch > config.pretrain_nepochs:
                # Anneals the Gumbel-softmax temperature.
                gamma_ = max(0.001, gamma_ * config.gamma_decay)
                lambda_g_ = config.lambda_g
            print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

            # Train
            iterator.restart_dataset(sess, ['train_g', 'train_d'])
            _train_epoch(sess, gamma_, lambda_g_, epoch)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess, gamma_, lambda_g_, epoch, 'val')

            saver.save(sess, os.path.join(config.checkpoint_path, 'ckpt'),
                       epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess, gamma_, lambda_g_, epoch, 'test')