def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
                  num_examples_for_eval):
  """Evaluates a generator/discriminator pair.

  This function builds a generator and a discriminator, restores both from a
  checkpoint on disk, and evaluates the discriminator against training data,
  validation data, and samples from the generator. It computes the mean
  discriminator probability over several batches and the FID of the generator
  samples against the validation data, and writes these results together with
  the generated samples to disk.

  Args:
    config: experiment config object, accessed by attribute
      (e.g. `config.dataset`).
    batch_size: int, size of the batch.
    checkpoint_path: string, full path to the TF checkpoint on disk.
    data_dir: string, path to a directory containing the dataset.
    dataset: string, "emnlp2017", to select the right dataset.
    num_examples_for_eval: int, number of examples for evaluation.
  """
  tf.reset_default_graph()
  logging.info("Evaluating checkpoint %s.", checkpoint_path)

  # Build graph.
  train_data, valid_data, word_to_id = reader.get_raw_data(
      data_dir, dataset=dataset)
  id_to_word = {v: k for k, v in word_to_id.iteritems()}
  vocab_size = len(word_to_id)
  train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
  valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
  train_sequence = tf.placeholder(
      dtype=tf.int32,
      shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
      name="train_sequence")
  train_sequence_length = tf.placeholder(
      dtype=tf.int32, shape=[batch_size], name="train_sequence_length")
  valid_sequence = tf.placeholder(
      dtype=tf.int32,
      shape=[batch_size, reader.MAX_TOKENS_SEQUENCE[dataset]],
      name="valid_sequence")
  valid_sequence_length = tf.placeholder(
      dtype=tf.int32, shape=[batch_size], name="valid_sequence_length")
  disc_inputs_train = {
      "sequence": train_sequence,
      "sequence_length": train_sequence_length,
  }
  disc_inputs_valid = {
      "sequence": valid_sequence,
      "sequence_length": valid_sequence_length,
  }

  if config.use_pretrained_embedding:
    embedding_source = utils.get_embedding_path(config.data_dir,
                                                config.dataset)
    vocab_file = "/tmp/vocab.txt"
    with gfile.GFile(vocab_file, "w") as f:
      for i in xrange(len(id_to_word)):
        f.write(id_to_word[i] + "\n")
    logging.info("Temporary vocab file: %s", vocab_file)
  else:
    embedding_source = None
    vocab_file = None

  gen = generators.LSTMGen(
      vocab_size=vocab_size,
      feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
      max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
      batch_size=config.batch_size,
      use_layer_norm=config.layer_norm_gen,
      trainable_embedding_size=config.trainable_embedding_size,
      input_dropout=config.gen_input_dropout,
      output_dropout=config.gen_output_dropout,
      pad_token=reader.PAD_INT,
      embedding_source=embedding_source,
      vocab_file=vocab_file,
  )
  gen_outputs = gen()

  disc = discriminator_nets.LSTMEmbedDiscNet(
      vocab_size=vocab_size,
      feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
      trainable_embedding_size=config.trainable_embedding_size,
      embedding_source=embedding_source,
      use_layer_norm=config.layer_norm_disc,
      pad_token=reader.PAD_INT,
      vocab_file=vocab_file,
      dropout=config.disc_dropout,
  )
  disc_inputs = {
      "sequence": gen_outputs["sequence"],
      "sequence_length": gen_outputs["sequence_length"],
  }
  gen_logits = disc(**disc_inputs)
  train_logits = disc(**disc_inputs_train)
  valid_logits = disc(**disc_inputs_valid)

  # Saver.
  saver = tf.train.Saver()

  # Reduce over time and batch.
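  # The discriminator emits one logit per token. A sigmoid maps each logit to
  # a per-token probability of the token being real, and averaging over both
  # the time and batch dimensions yields a single scalar score per data source
  # (training data, validation data, and generator samples). Because the same
  # `disc` module is applied to all three inputs, the scores presumably come
  # from shared discriminator parameters (assuming Sonnet-style variable reuse
  # across calls).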
  train_probs = tf.reduce_mean(tf.nn.sigmoid(train_logits))
  valid_probs = tf.reduce_mean(tf.nn.sigmoid(valid_logits))
  gen_probs = tf.reduce_mean(tf.nn.sigmoid(gen_logits))
  outputs = {
      "train_probs": train_probs,
      "valid_probs": valid_probs,
      "gen_probs": gen_probs,
      "gen_sequences": gen_outputs["sequence"],
      "valid_sequences": valid_sequence,
  }

  # Get average discriminator score and store generated sequences.
  all_valid_sentences = []
  all_gen_sentences = []
  all_gen_sequences = []
  mean_train_prob = 0.0
  mean_valid_prob = 0.0
  mean_gen_prob = 0.0
  logging.info("Graph constructed, generating batches.")
  num_batches = num_examples_for_eval // batch_size + 1

  # Restrict the thread pool size to prevent excessive GCU usage on Borg.
  tf_config = tf.ConfigProto()
  tf_config.intra_op_parallelism_threads = 16
  tf_config.inter_op_parallelism_threads = 16

  with tf.Session(config=tf_config) as sess:
    # Restore variables from the checkpoint.
    logging.info("Restoring variables.")
    saver.restore(sess, checkpoint_path)

    for i in xrange(num_batches):
      logging.info("Batch %d / %d", i, num_batches)
      train_data_np = train_iterator.next()
      valid_data_np = valid_iterator.next()
      feed_dict = {
          train_sequence: train_data_np["sequence"],
          train_sequence_length: train_data_np["sequence_length"],
          valid_sequence: valid_data_np["sequence"],
          valid_sequence_length: valid_data_np["sequence_length"],
      }
      outputs_np = sess.run(outputs, feed_dict=feed_dict)
      all_gen_sequences.extend(outputs_np["gen_sequences"])
      gen_sentences = utils.batch_sequences_to_sentences(
          outputs_np["gen_sequences"], id_to_word)
      valid_sentences = utils.batch_sequences_to_sentences(
          outputs_np["valid_sequences"], id_to_word)
      all_valid_sentences.extend(valid_sentences)
      all_gen_sentences.extend(gen_sentences)
      # Average the per-batch scores over the number of evaluation batches.
      mean_train_prob += outputs_np["train_probs"] / num_batches
      mean_valid_prob += outputs_np["valid_probs"] / num_batches
      mean_gen_prob += outputs_np["gen_probs"] / num_batches

  logging.info("Evaluating FID.")

  # Compute FID between generated samples and validation data.
  fid = eval_metrics.fid(
      generated_sentences=all_gen_sentences[:num_examples_for_eval],
      real_sentences=all_valid_sentences[:num_examples_for_eval])

  utils.write_eval_results(config.checkpoint_dir, all_gen_sentences,
                           os.path.basename(checkpoint_path), mean_train_prob,
                           mean_valid_prob, mean_gen_prob, fid)

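# Example invocation of evaluate_pair, as a sketch only: the `config` object
# and the specific field names used here (e.g. `config.num_examples_for_eval`)
# are assumed to be supplied by a surrounding experiment script and are not
# defined in this module.
#
#   checkpoint_path = tf.train.latest_checkpoint(config.checkpoint_dir)
#   evaluate_pair(
#       config=config,
#       batch_size=config.batch_size,
#       checkpoint_path=checkpoint_path,
#       data_dir=config.data_dir,
#       dataset=config.dataset,
#       num_examples_for_eval=config.num_examples_for_eval)
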
def train(config):
  """Trains the generator and discriminator adversarially."""
  logging.info("Training.")
  tf.reset_default_graph()
  np.set_printoptions(precision=4)

  # Get data.
  raw_data = reader.get_raw_data(
      data_path=config.data_dir, dataset=config.dataset)
  train_data, valid_data, word_to_id = raw_data
  id_to_word = {v: k for k, v in word_to_id.iteritems()}
  vocab_size = len(word_to_id)
  max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
  logging.info("Vocabulary size: %d", vocab_size)

  iterator = reader.iterator(raw_data=train_data, batch_size=config.batch_size)
  iterator_valid = reader.iterator(
      raw_data=valid_data, batch_size=config.batch_size)

  real_sequence = tf.placeholder(
      dtype=tf.int32,
      shape=[config.batch_size, max_length],
      name="real_sequence")
  real_sequence_length = tf.placeholder(
      dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")

  first_batch_np = iterator.next()
  valid_batch_np = iterator_valid.next()
  test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
  test_fake_batch = {
      "sequence": tf.constant(
          np.random.choice(
              vocab_size, size=[config.batch_size, max_length]).astype(
                  np.int32)),
      "sequence_length": tf.constant(
          np.random.choice(
              max_length, size=[config.batch_size]).astype(np.int32)),
  }
  valid_batch = {k: tf.constant(v) for k, v in valid_batch_np.items()}

  # Create generator.
  if config.use_pretrained_embedding:
    embedding_source = utils.get_embedding_path(config.data_dir,
                                                config.dataset)
    vocab_file = "/tmp/vocab.txt"
    with gfile.GFile(vocab_file, "w") as f:
      for i in xrange(len(id_to_word)):
        f.write(id_to_word[i] + "\n")
    logging.info("Temporary vocab file: %s", vocab_file)
  else:
    embedding_source = None
    vocab_file = None

  gen = generators.LSTMGen(
      vocab_size=vocab_size,
      feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
      max_sequence_length=reader.MAX_TOKENS_SEQUENCE[config.dataset],
      batch_size=config.batch_size,
      use_layer_norm=config.layer_norm_gen,
      trainable_embedding_size=config.trainable_embedding_size,
      input_dropout=config.gen_input_dropout,
      output_dropout=config.gen_output_dropout,
      pad_token=reader.PAD_INT,
      embedding_source=embedding_source,
      vocab_file=vocab_file,
  )
  gen_outputs = gen()

  # Create discriminator.
  disc = discriminator_nets.LSTMEmbedDiscNet(
      vocab_size=vocab_size,
      feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
      trainable_embedding_size=config.trainable_embedding_size,
      embedding_source=embedding_source,
      use_layer_norm=config.layer_norm_disc,
      pad_token=reader.PAD_INT,
      vocab_file=vocab_file,
      dropout=config.disc_dropout,
  )
  disc_logits_real = disc(
      sequence=real_sequence, sequence_length=real_sequence_length)
  disc_logits_fake = disc(
      sequence=gen_outputs["sequence"],
      sequence_length=gen_outputs["sequence_length"])

  # Loss of the discriminator.
  if config.disc_loss_type == "ce":
    targets_real = tf.ones(
        [config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
    targets_fake = tf.zeros(
        [config.batch_size, reader.MAX_TOKENS_SEQUENCE[config.dataset]])
    loss_real = losses.sequential_cross_entropy_loss(disc_logits_real,
                                                     targets_real)
    loss_fake = losses.sequential_cross_entropy_loss(disc_logits_fake,
                                                     targets_fake)
    disc_loss = 0.5 * loss_real + 0.5 * loss_fake
  else:
    raise ValueError("Unsupported disc_loss_type: %s" % config.disc_loss_type)

  # Loss of the generator.
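  # The generator is trained with a REINFORCE-style estimator (a sketch of the
  # idea; the exact reward definition lives in losses.reinforce_loss): a
  # per-timestep reward is derived from the discriminator logits on the
  # generated sequence, discounted by `gamma` into a cumulative reward, and a
  # moving-average baseline (controlled by `decay`) is subtracted to reduce
  # variance. Roughly:
  #   gen_loss_t ~ -(cumulative_reward_t - baseline_t) * gen_logprobs_t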
  gen_loss, cumulative_rewards, baseline = losses.reinforce_loss(
      disc_logits=disc_logits_fake,
      gen_logprobs=gen_outputs["logprobs"],
      gamma=config.gamma,
      decay=config.baseline_decay)

  # Optimizers.
  disc_optimizer = tf.train.AdamOptimizer(
      learning_rate=config.disc_lr, beta1=config.disc_beta1)
  gen_optimizer = tf.train.AdamOptimizer(
      learning_rate=config.gen_lr, beta1=config.gen_beta1)

  # Get losses and variables.
  disc_vars = disc.get_all_variables()
  gen_vars = gen.get_all_variables()
  l2_disc = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in disc_vars]))
  l2_gen = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in gen_vars]))
  scalar_disc_loss = tf.reduce_mean(disc_loss) + config.l2_disc * l2_disc
  scalar_gen_loss = tf.reduce_mean(gen_loss) + config.l2_gen * l2_gen

  # Update ops.
  global_step = tf.train.get_or_create_global_step()
  disc_update = disc_optimizer.minimize(
      scalar_disc_loss, var_list=disc_vars, global_step=global_step)
  gen_update = gen_optimizer.minimize(
      scalar_gen_loss, var_list=gen_vars, global_step=global_step)

  # Saver.
  saver = tf.train.Saver()

  # Metrics.
  test_disc_logits_real = disc(**test_real_batch)
  test_disc_logits_fake = disc(**test_fake_batch)
  valid_disc_logits = disc(**valid_batch)
  disc_predictions_real = tf.nn.sigmoid(disc_logits_real)
  disc_predictions_fake = tf.nn.sigmoid(disc_logits_fake)
  valid_disc_predictions = tf.reduce_mean(
      tf.nn.sigmoid(valid_disc_logits), axis=0)
  test_disc_predictions_real = tf.reduce_mean(
      tf.nn.sigmoid(test_disc_logits_real), axis=0)
  test_disc_predictions_fake = tf.reduce_mean(
      tf.nn.sigmoid(test_disc_logits_fake), axis=0)

  # Scalar metrics, averaged over the batch; only the first generated sentence
  # of the batch is logged below.
  metrics = {
      "scalar_gen_loss": scalar_gen_loss,
      "scalar_disc_loss": scalar_disc_loss,
      "disc_predictions_real": tf.reduce_mean(disc_predictions_real),
      "disc_predictions_fake": tf.reduce_mean(disc_predictions_fake),
      "test_disc_predictions_real": tf.reduce_mean(test_disc_predictions_real),
      "test_disc_predictions_fake": tf.reduce_mean(test_disc_predictions_fake),
      "valid_disc_predictions": tf.reduce_mean(valid_disc_predictions),
      "cumulative_rewards": tf.reduce_mean(cumulative_rewards),
      "baseline": tf.reduce_mean(baseline),
  }

  # Training.
  logging.info("Starting training.")
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    latest_ckpt = tf.train.latest_checkpoint(config.checkpoint_dir)
    if latest_ckpt:
      saver.restore(sess, latest_ckpt)

    for step in xrange(config.num_steps):
      real_data_np = iterator.next()
      train_feed = {
          real_sequence: real_data_np["sequence"],
          real_sequence_length: real_data_np["sequence_length"],
      }

      # Update the discriminator, then the generator.
      for _ in xrange(config.num_disc_updates):
        sess.run(disc_update, feed_dict=train_feed)
      for _ in xrange(config.num_gen_updates):
        sess.run(gen_update, feed_dict=train_feed)

      # Reporting.
      if step % config.export_every == 0:
        gen_sequence_np, metrics_np = sess.run(
            [gen_outputs["sequence"], metrics], feed_dict=train_feed)
        metrics_np["gen_sentence"] = utils.sequence_to_sentence(
            gen_sequence_np[0, :], id_to_word)
        saver.save(
            sess,
            save_path=config.checkpoint_dir + "scratchgan",
            global_step=global_step)
        metrics_np["model_path"] = tf.train.latest_checkpoint(
            config.checkpoint_dir)
        logging.info(metrics_np)

    # After training, export the final model.
    saver.save(
        sess,
        save_path=config.checkpoint_dir + "scratchgan",
        global_step=global_step)
    logging.info("Saved final model at %s.",
                 tf.train.latest_checkpoint(config.checkpoint_dir))
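

# Minimal entry-point sketch: `get_config` is a hypothetical helper standing
# in for however the surrounding experiment script builds its config (e.g.
# from command-line flags); it is not defined in this module.
#
#   if __name__ == "__main__":
#     config = get_config()
#     train(config)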