def main(_): """ Builds the model and runs """ np.random.seed(FLAGS.seed) tf.set_random_seed(FLAGS.seed) nsamples = FLAGS.nsamples batch_size = FLAGS.batch_size max_decoding_length = FLAGS.max_decoding_length # Load GPT-2 model configuration if FLAGS.config_type == "json": gpt2_config = model_utils.transform_gpt2_to_texar_config( FLAGS.config_model) elif FLAGS.config_type == "texar": gpt2_config = importlib.import_module( FLAGS.config_model) else: raise ValueError("Unknown config_type.") assert max_decoding_length <= gpt2_config.position_size, ( "max_decoding_length should not be greater than position size") assert nsamples % batch_size == 0, ( "nsamples must be dividable by batch_size") # Create a data pre-processor for, e.g., BPE encoding proc = processor.get_encoder( FLAGS.pretrained_model_dir) context = tf.placeholder(tf.int32, [batch_size, None]) context_length = tf.placeholder(tf.int32, [batch_size]) end_token = proc.encoder["<|endoftext|>"] if FLAGS.is_interactive: start_tokens = context[:, 0] else: start_tokens = tf.fill([batch_size], end_token) # Build the GPT-2 model word_embedder = tx.modules.WordEmbedder( vocab_size=gpt2_config.vocab_size, hparams=gpt2_config.embed) pos_embedder = tx.modules.PositionEmbedder( position_size=gpt2_config.position_size, hparams=gpt2_config.pos_embed) def _embedding_fn(x, y): # `x` is token ids, `y` is time steps return word_embedder(x) + pos_embedder(y) helper = tx.modules.TopKSampleEmbeddingHelper( embedding=_embedding_fn, start_tokens=start_tokens, end_token=end_token, top_k=FLAGS.top_k, softmax_temperature=FLAGS.temperature) output_layer = tf.transpose(word_embedder.embedding, (1, 0)) decoder = tx.modules.TransformerDecoder( vocab_size=gpt2_config.vocab_size, output_layer=output_layer, hparams=gpt2_config.decoder) with tf.Session() as sess: # Generate continuations of context lm_output, _ = decoder( context=context, context_sequence_length=context_length, max_decoding_length=max_decoding_length, helper=helper, mode=tf.estimator.ModeKeys.PREDICT) # Load model checkpoint if FLAGS.checkpoint: tf.logging.info("Restore from {}".format(FLAGS.checkpoint)) saver = tf.train.Saver() saver.restore(sess, FLAGS.checkpoint) elif FLAGS.pretrain_checkpoint: model_utils.init_gpt2_checkpoint( sess, FLAGS.pretrain_checkpoint) print("\nFinished loading\n") if FLAGS.is_interactive: # Enter interactive mode while True: story_title = input("Please enter a title! or q to exit >>> ") if story_title == "q": break emotion_arc_poem = input("Please enter a sequence of emotions, one for each stanza. Choose from: Beauty, Joy, Vitality, Humor, Uneasiness, Sadness, Suspense, Annoyance, Nostalgia, Awe, Sublime. Beauty/Joy and Awe/Sublime refer to the same emotion internally.\n>>> ") if emotion_arc_poem == "q": break # raw_text = raw_text + " | " while not story_title: print("Input should not be empty!") story_title = input("Please enter a title! or q to exit >>> ") emotion_arc_poem = input("Please enter a sequence of three emotions separated by space from joy, anger, sadness, fear, neutral! 
or q to exit >>> ") for _ in range(FLAGS.poem_batch_size): for emotion in emotion_arc_poem.split(): emotion_arc = normalize_emo(emotion) print(emotion_arc) raw_text = " <$> ".join((emotion_arc, story_title)) print(raw_text) context_tokens = proc.encode(raw_text) feed_dict = { context: [context_tokens for _ in range(batch_size)], context_length: [len(context_tokens) for _ in range(batch_size)], tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT } generated = 0 for _ in range(nsamples // batch_size): output = sess.run(lm_output, feed_dict=feed_dict) sample_id = output.sample_id for i in range(batch_size): si = sample_id[i][len(context_tokens):] s_text = proc.decode(si) s_text = normalize_stanza(s_text) print(s_text) # end of poem print("=" * 80) elif FLAGS.is_eval: eval_arcs_total = [] for arc in eval_arcs: for _ in range(FLAGS.eval_poems_per_arc): eval_arcs_total.append(arc) random.shuffle(eval_arcs_total) eval_arcs_file = open("eval_arcs.txt", "w") poems_file = open("eval_poems.txt", "w") for arc in eval_arcs_total: eval_arcs_file.write(arc + "\n") for emotion_arc_poem in eval_arcs_total: story_title = eval_title for emotion in emotion_arc_poem.split(): emotion_arc = normalize_emo(emotion) print("Generating: " + story_title + " / " + emotion_arc) raw_text = " <$> ".join((emotion_arc, story_title)) context_tokens = proc.encode(raw_text) feed_dict = { context: [context_tokens for _ in range(batch_size)], context_length: [len(context_tokens) for _ in range(batch_size)], tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT } generated = 0 for _ in range(nsamples // batch_size): output = sess.run(lm_output, feed_dict=feed_dict) sample_id = output.sample_id for i in range(batch_size): si = sample_id[i][len(context_tokens):] s_text = proc.decode(si) s_text = normalize_stanza(s_text) print(s_text) print() poems_file.write(s_text + "\n\n") # end of poem print("=" * 80) poems_file.write("="*80+"\n")
def main(_): """ Builds the model and runs """ np.random.seed(FLAGS.seed) tf.set_random_seed(FLAGS.seed) nsamples = FLAGS.nsamples batch_size = FLAGS.batch_size max_decoding_length = FLAGS.max_decoding_length # Load GPT-2 model configuration if FLAGS.config_type == "json": gpt2_config = model_utils.transform_gpt2_to_texar_config( FLAGS.config_model) elif FLAGS.config_type == "texar": gpt2_config = importlib.import_module(FLAGS.config_model) else: raise ValueError("Unknown config_type.") assert max_decoding_length <= gpt2_config.position_size, ( "max_decoding_length should not be greater than position size") assert nsamples % batch_size == 0, ( "nsamples must be dividable by batch_size") # Create a data pre-processor for, e.g., BPE encoding proc = processor.get_encoder(FLAGS.pretrained_model_dir) context = tf.placeholder(tf.int32, [batch_size, None]) context_length = tf.placeholder(tf.int32, [batch_size]) end_token = proc.encoder["<|endoftext|>"] if FLAGS.is_interactive: start_tokens = context[:, 0] else: start_tokens = tf.fill([batch_size], end_token) # Build the GPT-2 model word_embedder = tx.modules.WordEmbedder(vocab_size=gpt2_config.vocab_size, hparams=gpt2_config.embed) pos_embedder = tx.modules.PositionEmbedder( position_size=gpt2_config.position_size, hparams=gpt2_config.pos_embed) def _embedding_fn(x, y): # `x` is token ids, `y` is time steps return word_embedder(x) + pos_embedder(y) helper = tx.modules.TopKSampleEmbeddingHelper( embedding=_embedding_fn, start_tokens=start_tokens, end_token=end_token, top_k=FLAGS.top_k, softmax_temperature=FLAGS.temperature) output_layer = tf.transpose(word_embedder.embedding, (1, 0)) decoder = tx.modules.TransformerDecoder(vocab_size=gpt2_config.vocab_size, output_layer=output_layer, hparams=gpt2_config.decoder) with tf.Session() as sess: if FLAGS.is_interactive: # Generate continuations of context lm_output, _ = decoder(context=context, context_sequence_length=context_length, max_decoding_length=max_decoding_length, helper=helper, mode=tf.estimator.ModeKeys.PREDICT) # Load model checkpoint if FLAGS.checkpoint: tf.logging.info("Restore from {}".format(FLAGS.checkpoint)) saver = tf.train.Saver() saver.restore(sess, FLAGS.checkpoint) elif FLAGS.pretrain_checkpoint: model_utils.init_gpt2_checkpoint(sess, FLAGS.pretrain_checkpoint) print("\nFinished loading\n") # Enter interactive mode while True: story_title = input("Please enter a title! or q to exit >>> ") if story_title == "q": break emotion_arc = input( "Please enter a sequence of three emotions separated by space from joy, anger, sadness, fear, neutral! for example: joy sadness sadness, or q to exit >>> " ) if emotion_arc == "q": break # raw_text = raw_text + " | " while not story_title: print("Input should not be empty!") story_title = input( "Please enter a title! or q to exit >>> ") emotion_arc = input( "Please enter a sequence of three emotions separated by space from joy, anger, sadness, fear, neutral! 
or q to exit >>> " ) raw_text = " <$> ".join((emotion_arc, story_title)) context_tokens = proc.encode(raw_text) feed_dict = { context: [context_tokens for _ in range(batch_size)], context_length: [len(context_tokens) for _ in range(batch_size)], tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT } generated = 0 for _ in range(nsamples // batch_size): output = sess.run(lm_output, feed_dict=feed_dict) sample_id = output.sample_id for i in range(batch_size): generated += 1 print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) si = sample_id[i][len(context_tokens):] s_text = proc.decode(si) s_text = s_text[:s_text.find('<|endoftext|>')].strip( ) if '<|endoftext|>' in s_text else s_text print(s_text) print("=" * 80) else: # Generate samples from scratch lm_output, _ = decoder(max_decoding_length=max_decoding_length, helper=helper, mode=tf.estimator.ModeKeys.PREDICT) # Load model checkpoint if FLAGS.checkpoint: tf.logging.info("Restore from {}".format(FLAGS.checkpoint)) saver = tf.train.Saver() saver.restore(sess, FLAGS.checkpoint) elif FLAGS.pretrain_checkpoint: model_utils.init_gpt2_checkpoint(sess, FLAGS.pretrain_checkpoint) print("\nFinished loading\n") feed_dict = { tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT } generated = 0 while nsamples == 0 or generated < nsamples: output = sess.run(lm_output, feed_dict=feed_dict) sample_id = output.sample_id for i in range(batch_size): generated += batch_size text = proc.decode(sample_id[i]) print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print(text)
def main(_): """ Builds the model and runs """ if FLAGS.distributed: import horovod.tensorflow as hvd hvd.init() tf.logging.set_verbosity(tf.logging.INFO) if len(config_train.name) > 0: output_dir = os.path.join(FLAGS.output_dir, config_train.name) else: output_dir = FLAGS.output_dir tx.utils.maybe_create_dir(output_dir) ## Loads GPT-2 model configuration if FLAGS.config_type == "json": gpt2_config = model_utils.transform_gpt2_to_texar_config( FLAGS.config_model) elif FLAGS.config_type == 'texar': gpt2_config = importlib.import_module(FLAGS.config_model) else: raise ValueError('Unknown config_type.') # Creates a data pre-processor for, e.g., BPE encoding proc = processor.get_encoder(FLAGS.pretrained_model_dir) end_token = proc.encoder['<|endoftext|>'] max_decoding_length = config_train.max_decoding_length assert max_decoding_length <= gpt2_config.position_size, ( "max_decoding_length should not be greater than position_size. " "{}>{}".format(max_decoding_length, gpt2_config.position_size)) ## Loads data # Configures training data shard in distribued mode if FLAGS.distributed: config_train.train_hparam["dataset"]["num_shards"] = hvd.size() config_train.train_hparam["dataset"]["shard_id"] = hvd.rank() config_train.train_hparam["batch_size"] //= hvd.size() datasets = {} #if FLAGS.do_train: train_dataset = tx.data.TFRecordData(hparams=config_train.train_hparam) datasets['train'] = train_dataset #if FLAGS.do_eval: dev_dataset = tx.data.TFRecordData(hparams=config_train.dev_hparam) datasets['dev'] = dev_dataset #if FLAGS.do_test: test_dataset = tx.data.TFRecordData(hparams=config_train.test_hparam) datasets['test'] = test_dataset iterator = tx.data.FeedableDataIterator(datasets) batch = iterator.get_next() batch_size = tf.shape(batch['x1x4_ids'])[0] ## Builds the GPT-2 model vocab_size = gpt2_config.vocab_size word_embedder = tx.modules.WordEmbedder(vocab_size=vocab_size, hparams=gpt2_config.embed) pos_embedder = tx.modules.PositionEmbedder( position_size=gpt2_config.position_size, hparams=gpt2_config.pos_embed) # Ties output layer with input word embedding output_layer = tf.transpose(word_embedder.embedding, (1, 0)) decoder = tx.modules.TransformerDecoder(vocab_size=vocab_size, output_layer=output_layer, hparams=gpt2_config.decoder) def _embedding_fn(ids, times): return word_embedder(ids) + pos_embedder(times) # For training def _get_recon_loss(ids, full_len, prefix_len=None, mask_prefix=True, do_print=False): ids = ids[:, :tf.reduce_max(full_len)] batch_size__ = tf.shape(ids)[0] seq_len = tf.fill([batch_size__], tf.shape(ids)[1]) pos_embeds = pos_embedder(sequence_length=seq_len) input_embeds = word_embedder(ids) + pos_embeds # greedy output outputs = decoder(inputs=input_embeds, decoding_strategy='train_greedy') max_full_len = tf.reduce_max(full_len) ids = ids[:, :max_full_len] logits = outputs.logits[:, :max_full_len] if mask_prefix: loss_recon = tx.losses.sequence_sparse_softmax_cross_entropy( labels=ids[:, 1:], logits=logits[:, :-1, :], sequence_length=full_len - 1, average_across_timesteps=False, sum_over_timesteps=False, average_across_batch=False, sum_over_batch=False) mask_recon = tf.sequence_mask(full_len - 1, dtype=tf.float32) mask_recon_prefix = 1 - tf.sequence_mask( prefix_len - 1, maxlen=max_full_len - 1, #max_decoding_length-1, dtype=tf.float32) mask_recon = mask_recon * mask_recon_prefix if do_print: print_op_1 = tf.print(mask_recon) loss_recon_flat = tx.utils.reduce_with_weights( tensor=loss_recon, weights=mask_recon, average_across_remaining=False, sum_over_remaining=False, 
average_across_batch=False) print_op_2 = tf.print(loss_recon_flat) with tf.control_dependencies([print_op_1, print_op_2]): loss_recon = tx.utils.reduce_with_weights( tensor=loss_recon, weights=mask_recon, average_across_remaining=True, sum_over_remaining=False) return loss_recon, mask_recon, loss_recon_flat else: loss_recon = tx.utils.reduce_with_weights( tensor=loss_recon, weights=mask_recon, average_across_remaining=True, sum_over_remaining=False) else: loss_recon = tx.losses.sequence_sparse_softmax_cross_entropy( labels=ids[:, 1:], logits=logits[:, :-1, :], sequence_length=full_len - 1, average_across_timesteps=True, sum_over_timesteps=False, average_across_batch=False, sum_over_batch=False) return loss_recon # For RL fine-tuning def _get_sample_story(context_ids, context_len): sample_output, sample_len = decoder( decoding_strategy='infer_sample', embedding=_embedding_fn, context=context_ids, context_sequence_length=context_len, max_decoding_length=max_decoding_length, end_token=end_token, softmax_temperature=FLAGS.temperature, mode=tf.estimator.ModeKeys.PREDICT) return sample_output, sample_len # return ids, batch_loss, ids_len def _get_sample_rolled(output, length, context_len): ids = output.sample_id ids = tx.utils.varlength_roll(ids, -context_len) # final sample ids rolled ids_len = length - context_len ids = ids[:, :tf.reduce_max(ids_len)] return ids, ids_len def compute_batch_loss(output, sample_len, context_len): max_full_len = tf.reduce_max(sample_len) ids = output.sample_id[:, :max_full_len] logits = output.logits[:, :max_full_len] #(bs, sl, vocab) sampleLogprobs = tx.losses.sequence_sparse_softmax_cross_entropy( labels=ids[:, 1:], logits=logits, sequence_length=sample_len - 1, average_across_timesteps=False, sum_over_timesteps=False, average_across_batch=False, sum_over_batch=False) mask = tf.sequence_mask(sample_len - 1, dtype=tf.float32) mask_prefix = 1 - tf.sequence_mask( context_len - 1, maxlen=max_full_len - 1, #max_decoding_length-1, dtype=tf.float32) mask = mask * mask_prefix batch_loss = tx.utils.reduce_with_weights( tensor=sampleLogprobs, weights=mask, average_across_batch=False, average_across_remaining=True, sum_over_remaining=False) return batch_loss def _get_greedy_story(context_ids, context_len): greedy_res, greedy_len = decoder( decoding_strategy='infer_greedy', embedding=_embedding_fn, context=context_ids, context_sequence_length=context_len, max_decoding_length=max_decoding_length, end_token=end_token, mode=tf.estimator.ModeKeys.PREDICT) greedy_ids = tx.utils.varlength_roll(greedy_res.sample_id, -context_len) greedy_ids_len = greedy_len - context_len greedy_ids = greedy_ids[:, :tf.reduce_max(greedy_ids_len)] return greedy_ids, greedy_ids_len ## ROC Loss-1: ML loss x1_len = tf.placeholder(tf.int32, shape=[None], name='x1_len') x1x4_ids = tf.placeholder(tf.int32, shape=[None, None], name='x1x4_ids') x1x4_len = tf.placeholder(tf.int32, shape=[None], name='x1x4_len') loss_fine = _get_recon_loss(x1x4_ids, x1x4_len, x1_len) x1_ids = tf.placeholder(tf.int32, shape=[None, None], name='x1_ids') reward = tf.placeholder(tf.float32, shape=[None], name="reward") sampled_story = tf.placeholder(tf.int32, shape=[None, None], name="sampled_story") #smilar to sample_que sampled_story_len = tf.placeholder(tf.int32, shape=[None], name='sample_story_len') ## Loss-2: RL loss symbols_output, symbols_len = _get_sample_story(x1_ids, x1_len) symbols_rl, len_rl = _get_sample_rolled(symbols_output, symbols_len, x1_len) symbols_gr, len_gr = _get_greedy_story(x1_ids, x1_len) batch_loss_rl 
= _get_recon_loss(sampled_story, sampled_story_len, mask_prefix=False) rl_loss_fine = tf.reduce_mean(batch_loss_rl * reward) def _get_beam_ids(context_ids, context_len, target): # beam-search predictions = decoder(beam_width=5, length_penalty=config_train.length_penalty, embedding=_embedding_fn, context=context_ids, context_sequence_length=context_len, max_decoding_length=max_decoding_length, end_token=end_token, mode=tf.estimator.ModeKeys.PREDICT) beam_output_ids = tx.utils.varlength_roll( predictions["sample_id"][:, :, 0], -context_len) target_ids = tx.utils.varlength_roll(target, -context_len) return beam_output_ids, target_ids target_ids = tx.utils.varlength_roll(x1x4_ids, -x1_len) tau = tf.placeholder(tf.float32, shape=[], name='tau') if not FLAGS.sc_rl: loss = config_train.w_fine * loss_fine loss_dict = { 'loss': loss, 'loss_fine': config_train.w_fine * loss_fine, } else: loss = (1 - config_train.w_rl ) * config_train.w_fine * loss_fine + config_train.w_rl * ( config_train.w_fine_rl * rl_loss_fine) # loss_dict = { 'loss': loss, 'loss_fine': (1 - config_train.w_rl) * config_train.w_fine * loss_fine, 'rl_loss_fine': config_train.w_rl * config_train.w_fine_rl * rl_loss_fine, } ## Inference def _infer(context_name): helper = tx.modules.TopKSampleEmbeddingHelper( embedding=_embedding_fn, start_tokens=batch['%s_ids' % context_name][:, 0], end_token=end_token, top_k=FLAGS.top_k, softmax_temperature=FLAGS.temperature) outputs_infer, len_infer = decoder( context=batch['%s_ids' % context_name], context_sequence_length=batch['%s_len' % context_name], max_decoding_length=max_decoding_length, helper=helper) # outputs_infer contains sample_id and logits yy_ids = tx.utils.varlength_roll( outputs_infer.sample_id, -batch['%s_len' % context_name]) # shift beginning indices (context) to end yy_len = len_infer - batch['%s_len' % context_name] yy_ids = yy_ids[:, :tf.reduce_max(yy_len)] return yy_ids, yy_len x4_ids_fine, x4_len_fine = _infer('x1') def _infer_beam_ids(context_name): # beam-search predictions = decoder(beam_width=5, length_penalty=config_train.length_penalty, embedding=_embedding_fn, context=batch['%s_ids' % context_name], context_sequence_length=batch['%s_len' % context_name], max_decoding_length=max_decoding_length, end_token=end_token, mode=tf.estimator.ModeKeys.PREDICT) beam_output_ids = tx.utils.varlength_roll( predictions["sample_id"][:, :, 0], -batch['%s_len' % context_name]) return beam_output_ids beam_search_ids = _infer_beam_ids('x1') ## Optimization trainable_variables = tx.utils.collect_trainable_variables( [word_embedder, pos_embedder, decoder]) global_step = tf.Variable(0, trainable=False) opt = tx.core.get_optimizer(global_step=global_step, hparams=config_train.opt) if FLAGS.distributed: opt = hvd.DistributedOptimizer(opt) train_op = tf.contrib.layers.optimize_loss(loss=loss, global_step=global_step, learning_rate=None, optimizer=opt, variables=trainable_variables) ## Train/eval/test routine saver = tf.train.Saver() saver_best = tf.train.Saver(max_to_keep=1) dev_best = { 'loss': 1e8, 'loss_fine': 1e8, 'rl_loss_fine': 1e8, 'best_reward': -1e8, 'bleu': 0., 'meteor': 0. 
} #'best_reward': -1e8 def _log_losses(losses, step=None): loss_str = 'loss: %.4f, loss_fine: %.4f, rl_loss_fine: %.4f' % \ (losses['loss'], losses['loss_fine'], losses['rl_loss_fine'] ) if step is not None: loss_str = 'step: %d, %s' % (step, loss_str) _log(loss_str) def _is_head(): if not FLAGS.distributed: return True else: return hvd.rank() == 0 def _train_epoch(sess, initial=False): """Trains on the training set, and evaluates on the dev set periodically. """ # load train arc label data train_arc_file = [ i.strip().split() for i in open( os.path.join(config_train.arc_data, "train_mapped.txt")) ] iterator.restart_dataset(sess, 'train') while True: try: # (1) Get data and yy sample fetches_data = { 'batch': batch, 'batch_size': batch_size, } feed_dict_data = { iterator.handle: iterator.get_handle(sess, 'train'), tx.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets_data = sess.run(fetches_data, feed_dict_data) reward_fetches = { 'sample_rl': symbols_rl, 'sample_len': len_rl, 'greedy_sym': symbols_gr, 'greedy_len': len_gr, } reward_rets = sess.run(reward_fetches, feed_dict={ x1_ids: rets_data['batch']['x1_ids'], x1_len: rets_data['batch']['x1_len'], tx.global_mode(): tf.estimator.ModeKeys.PREDICT }) # prepare sample stories for classification ids_rl, text_rl = _get_text( proc, reward_rets['sample_rl'], reward_rets['sample_len']) #list of list story_rl = format_generated_stories_for_clf( text_rl, FLAGS.rl_method) #print("Rl Story: ", story_rl) _, text_base = _get_text(proc, reward_rets['greedy_sym'], reward_rets['greedy_len']) story_base = format_generated_stories_for_clf( text_base, FLAGS.rl_method) #print("Greedy Story", story_base) # add reward calculation here reward_rl = get_reward(predictor, story_rl, rets_data['batch']['unique_id'], train_arc_file, method=FLAGS.rl_method) reward_base = get_reward(predictor, story_base, rets_data['batch']['unique_id'], train_arc_file, method=FLAGS.rl_method) # self-critical reward reward_sc = [ rr - rb for rr, rb in zip(reward_rl, reward_base) ] # class list # print(reward_rl, reward_base, reward_sc) ids_rl = utils.list_strip_eos(ids_rl, end_token) new_in_sample_ids, new_in_sample_len = _fix(ids_rl, end_token) # (2) Optimize loss feed_dict = { x1_ids: rets_data['batch']['x1_ids'], x1_len: rets_data['batch']['x1_len'], x1x4_ids: rets_data['batch']['x1x4_ids'], x1x4_len: rets_data['batch']['x1x4_len'], sampled_story: new_in_sample_ids, sampled_story_len: new_in_sample_len, tau: config_train.tau, tx.global_mode(): tf.estimator.ModeKeys.TRAIN, reward: np.array(reward_sc) } fetches = { 'train_op': train_op, 'step': global_step, } fetches.update(loss_dict) rets = sess.run(fetches, feed_dict, options=run_opts) step = rets['step'] dis_steps = config_train.display_steps if _is_head() and dis_steps > 0 and step % dis_steps == 0: _log_losses(rets, step) eval_steps = config_train.eval_steps if _is_head() and eval_steps > 0 and step % eval_steps == 0: _dev_epoch(sess, evaluate_func=evaluate_full) # not used sample_steps = config_train.sample_steps if _is_head( ) and sample_steps > 0 and step % sample_steps == 0: print('-----------testing-----------------') _test_epoch(sess, step=step) # not used ckpt_steps = config_train.checkpoint_steps if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0: ckpt_fn = os.path.join(output_dir, 'model.ckpt') ckpt_fn = saver.save(sess, ckpt_fn, global_step=step) _log('Checkpoint to {}'.format(ckpt_fn)) except tf.errors.OutOfRangeError: break def _dev_epoch(sess, evaluate_func=evaluate_full): """Evaluates on the dev set. 
""" dev_arc_file = [ i.strip().split() for i in open( os.path.join(config_train.arc_data, "dev_mapped.txt")) ] with open( os.path.join(config_train.tfrecord_data_dir, "x4_emo_features.dev"), 'rb') as fp: emotion_feats = np.array(pickle.load(fp)) iterator.restart_dataset(sess, 'dev') nsamples = 0 hypotheses = [] references = [] reward_score = [] losses = [] hypotheses_dict = {} while True: try: # (1) Get data and yy sample fetches_data = { 'batch': batch, 'batch_size': batch_size, } feed_dict_data = { iterator.handle: iterator.get_handle(sess, 'dev'), tx.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets_data = sess.run(fetches_data, feed_dict_data) # (2) eval loss feed_dict = { x1_ids: rets_data['batch']['x1_ids'], x1_len: rets_data['batch']['x1_len'], x1x4_ids: rets_data['batch']['x1x4_ids'], x1x4_len: rets_data['batch']['x1x4_len'], # x4_emo: rets_data['batch']['x4_emo'], tau: config_train.tau, tx.global_mode(): tf.estimator.ModeKeys.PREDICT, } # rets_loss = sess.run(fetches, feed_dict) fetches = { 'loss_fine': loss_dict['loss_fine'], #'beam_search_ids': beam_search_ids, 'greedy_sym': symbols_gr, 'greedy_len': len_gr, 'target_ids': target_ids } rets = sess.run(fetches, feed_dict) losses.append(rets['loss_fine']) _, beam_text = _get_text(proc, rets['greedy_sym'], rets['greedy_len']) beam_story = format_generated_stories_for_clf( beam_text, FLAGS.rl_method) _, target_text = _get_text(proc, rets['target_ids'], rets_data['batch']['x1x4_len']) hypotheses.extend(beam_text) references.extend(target_text) hypotheses_dict_ = generate_all_valid_sample_dict( predictor, rets_data['batch']['unique_id'], beam_story, method=FLAGS.rl_method) for key, react in hypotheses_dict_.items(): if key not in hypotheses_dict: hypotheses_dict[ key] = react # dictionary key=unique_id value =list of list nsamples += rets_data['batch_size'] except tf.errors.OutOfRangeError: break avg_loss = np.mean(losses) metrics = evaluate_func(references, hypotheses, hypotheses_dict, dev_arc_file, emotion_feats, method=FLAGS.rl_method) msg = 'loss_fine: %.4f, bleu: %.4f, meteor: %.4f, reward: %.4f' % \ (avg_loss, metrics['bleu'], metrics['meteor'], metrics["best_reward"] ) _log('nsamples validation: %d' % nsamples) _log(msg) if FLAGS.best_model == "emotion": if FLAGS.do_train and metrics["best_reward"] > dev_best[ 'best_reward']: # dev_best.update(results.avg()) dev_best['loss_fine'] = avg_loss dev_best['best_reward'] = metrics["best_reward"] dev_best.update(metrics) ckpt_fn = os.path.join(output_dir, 'model_best.ckpt') ckpt_fn = saver_best.save(sess, ckpt_fn) _log('Checkpoint best to {}'.format(ckpt_fn)) elif FLAGS.best_model == "bleu": if FLAGS.do_train and metrics["bleu"] > dev_best['bleu']: # dev_best.update(results.avg()) dev_best['loss_fine'] = avg_loss dev_best['best_reward'] = metrics["best_reward"] dev_best.update(metrics) ckpt_fn = os.path.join(output_dir, 'model_best.ckpt') ckpt_fn = saver_best.save(sess, ckpt_fn) _log('Checkpoint best to {}'.format(ckpt_fn)) elif FLAGS.do_train and avg_loss < dev_best['loss']: # dev_best.update(results.avg()) dev_best['loss_fine'] = avg_loss dev_best.update(metrics) dev_best['best_reward'] = metrics["best_reward"] ckpt_fn = os.path.join(output_dir, 'model_best.ckpt') ckpt_fn = saver_best.save(sess, ckpt_fn) _log('Checkpoint best to {}'.format(ckpt_fn)) def _test_epoch(sess, step=None): """Generates samples on the test set. 
""" iterator.restart_dataset(sess, 'test') _all_inputs = [] _all_samples = [] if FLAGS.finetune: _log('Generation input: x1') fetches = { 'inputs': batch['x1_ids'], 'length': batch['x1_len'], 'samples_length': x4_len_fine, 'samples': x4_ids_fine } res_fn_appendix = "x1" while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'test'), tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets = sess.run(fetches, feed_dict=feed_dict) _inputs = [] for i, l in zip(rets['inputs'], rets['length']): # Delete padding _inputs.append(i[:l].tolist()) _all_inputs.extend(_inputs) _samples = [] for s, l in zip(rets['samples'], rets['samples_length']): _samples.append(s[:l].tolist( )) # rets['samples'] are np array [bs, max_seq_len=200] _all_samples.extend(_samples) except tf.errors.OutOfRangeError: break # Parse samples and write to file eos_token_id = proc.encoder['<|endoftext|>'] _all_input_text = [] for i in _all_inputs: if i[0] == eos_token_id: i = i[1:] i_text = proc.decode(i) _all_input_text.append(i_text) _all_input_text = tx.utils.strip_eos(_all_input_text, eos_token='<|endoftext|>') _all_samples_text = [] for i, s in zip(_all_inputs, _all_samples): s_text = proc.decode(s) s_text = s_text.strip(" |").replace('\n', ' ') _all_samples_text.append(s_text) _all_samples_text = tx.utils.strip_eos(_all_samples_text, eos_token='<|endoftext|>') if step is None: fn = "test_samples_%s.tsv" % res_fn_appendix else: fn = "test_samples_%s_%d.tsv" % (res_fn_appendix, step) output_file = os.path.join(output_dir, fn) _log('Write samples to {}'.format(output_file)) tx.utils.write_paired_text(_all_input_text, _all_samples_text, output_file) # Broadcasts global variables from rank-0 process if FLAGS.distributed: bcast = hvd.broadcast_global_variables(0) session_config = tf.ConfigProto() if FLAGS.distributed: session_config.gpu_options.visible_device_list = str(hvd.local_rank()) session_config.gpu_options = tf.GPUOptions(allow_growth=True) with tf.Session(config=session_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) #smry_writer = tf.summary.FileWriter(FLAGS.output_dir, graph=sess.graph) if FLAGS.distributed: bcast.run() #Restores trained model if specified if FLAGS.checkpoint: _log('Restore from {}'.format(FLAGS.checkpoint)) saver.restore(sess, FLAGS.checkpoint) elif FLAGS.pretrain_checkpoint: _log('Restore from {}'.format(FLAGS.pretrain_checkpoint)) model_utils.init_gpt2_checkpoint(sess, FLAGS.pretrain_checkpoint) print("\nFinished loading\n") saver.save(sess, output_dir + '/gpt2_model.ckpt') iterator.initialize_dataset(sess) if FLAGS.do_train: for epoch in range(config_train.max_train_epoch): print("Training epoch {}".format(epoch)) _train_epoch(sess, epoch == 0) saver.save(sess, output_dir + '/model.ckpt') if FLAGS.do_eval: _dev_epoch(sess) if FLAGS.do_test: _test_epoch(sess)
def main(_): """ Builds the model and runs """ if FLAGS.distributed: import horovod.tensorflow as hvd hvd.init() tf.logging.set_verbosity(tf.logging.INFO) if len(config_train.name) > 0: output_dir = os.path.join(FLAGS.output_dir, config_train.name) else: output_dir = FLAGS.output_dir tx.utils.maybe_create_dir(output_dir) ## Loads GPT-2 model configuration if FLAGS.config_type == "json": gpt2_config = model_utils.transform_gpt2_to_texar_config( FLAGS.config_model) elif FLAGS.config_type == 'texar': gpt2_config = importlib.import_module( FLAGS.config_model) else: raise ValueError('Unknown config_type.') # Creates a data pre-processor for, e.g., BPE encoding proc = processor.get_encoder(FLAGS.pretrained_model_dir) max_decoding_length = config_train.max_decoding_length assert max_decoding_length <= gpt2_config.position_size, ( "max_decoding_length should not be greater than position_size. " "{}>{}".format(max_decoding_length, gpt2_config.position_size)) ## Loads data # Configures training data shard in distribued mode if FLAGS.distributed: config_train.train_hparam["dataset"]["num_shards"] = hvd.size() config_train.train_hparam["dataset"]["shard_id"] = hvd.rank() config_train.train_hparam["batch_size"] //= hvd.size() datasets = {} #if FLAGS.do_train: train_dataset = tx.data.TFRecordData(hparams=config_train.train_hparam) datasets['train'] = train_dataset #if FLAGS.do_eval: dev_dataset = tx.data.TFRecordData(hparams=config_train.dev_hparam) datasets['dev'] = dev_dataset #if FLAGS.do_test: test_dataset = tx.data.TFRecordData(hparams=config_train.test_hparam) datasets['test'] = test_dataset iterator = tx.data.FeedableDataIterator(datasets) batch = iterator.get_next() batch_size = tf.shape(batch['x1x4_ids'])[0] ## Builds the GPT-2 model vocab_size = gpt2_config.vocab_size word_embedder = tx.modules.WordEmbedder( vocab_size=vocab_size, hparams=gpt2_config.embed) pos_embedder = tx.modules.PositionEmbedder( position_size=gpt2_config.position_size, hparams=gpt2_config.pos_embed) # Ties output layer with input word embedding output_layer = tf.transpose(word_embedder.embedding, (1, 0)) decoder = tx.modules.TransformerDecoder( vocab_size=vocab_size, output_layer=output_layer, hparams=gpt2_config.decoder) # For training def _get_recon_loss(ids, full_len, prefix_len, mask_prefix=True, do_print=False): ids = ids[:,:tf.reduce_max(full_len)] batch_size__ = tf.shape(ids)[0] seq_len = tf.fill([batch_size__], tf.shape(ids)[1]) pos_embeds = pos_embedder(sequence_length=seq_len) input_embeds = word_embedder(ids) + pos_embeds outputs = decoder(inputs=input_embeds, decoding_strategy='train_greedy') max_full_len = tf.reduce_max(full_len) ids = ids[:, :max_full_len] logits = outputs.logits[:, :max_full_len] if mask_prefix: loss_recon = tx.losses.sequence_sparse_softmax_cross_entropy( labels=ids[:, 1:], logits=logits[:, :-1, :], sequence_length=full_len-1, average_across_timesteps=False, sum_over_timesteps=False, average_across_batch=False, sum_over_batch=False) mask_recon = tf.sequence_mask( full_len-1, dtype=tf.float32) mask_recon_prefix = 1 - tf.sequence_mask( prefix_len-1, maxlen=max_full_len-1,#max_decoding_length-1, dtype=tf.float32) mask_recon = mask_recon * mask_recon_prefix if do_print: print_op_1 = tf.print(mask_recon) loss_recon_flat = tx.utils.reduce_with_weights( tensor=loss_recon, weights=mask_recon, average_across_remaining=False, sum_over_remaining=False, average_across_batch=False) print_op_2 = tf.print(loss_recon_flat) with tf.control_dependencies([print_op_1, print_op_2]): loss_recon = 
tx.utils.reduce_with_weights( tensor=loss_recon, weights=mask_recon, average_across_remaining=True, sum_over_remaining=False) return loss_recon, mask_recon, loss_recon_flat else: loss_recon = tx.utils.reduce_with_weights( tensor=loss_recon, weights=mask_recon, average_across_remaining=True, sum_over_remaining=False) else: loss_recon = tx.losses.sequence_sparse_softmax_cross_entropy( labels=ids[:, 1:], logits=logits[:, :-1, :], sequence_length=full_len-1, average_across_timesteps=True, sum_over_timesteps=False, average_across_batch=True, sum_over_batch=False) return loss_recon ## ROC Loss-1: fine-tune loss x1_len = tf.placeholder(tf.int32, shape=[None], name='x1_len') x1x4_ids = tf.placeholder(tf.int32, shape=[None, None], name='x1x4_ids') x1x4_len = tf.placeholder(tf.int32, shape=[None], name='x1x4_len') loss_fine = _get_recon_loss(x1x4_ids, x1x4_len, x1_len) tau = tf.placeholder(tf.float32, shape=[], name='tau') # generate soft yy def _soft_embedding_fn(soft_ids, times): return word_embedder(soft_ids=soft_ids) + pos_embedder(times) end_token = proc.encoder['<|endoftext|>'] if not FLAGS.supervised: loss = config_train.w_fine * loss_fine loss_dict = { 'loss': loss, 'loss_fine': config_train.w_fine * loss_fine, } else: loss = loss_yy loss_dict = { 'loss': loss, 'loss_yy': loss_yy, # dumb 'loss_mask_recon': tf.constant(0), 'loss_bt': tf.constant(0), 'loss_d_xx2': tf.constant(0), 'loss_d_x2': tf.constant(0), 'loss_fine': tf.constant(0), 'loss_xx2': tf.constant(0) } ## Inference def _embedding_fn(ids, times): return word_embedder(ids) + pos_embedder(times) def _infer(context_name): helper = tx.modules.TopKSampleEmbeddingHelper( embedding=_embedding_fn, start_tokens=batch['%s_ids' % context_name][:, 0], end_token=end_token, top_k=FLAGS.top_k, softmax_temperature=FLAGS.temperature) outputs_infer, len_infer = decoder( context=batch['%s_ids' % context_name], context_sequence_length=batch['%s_len' % context_name], max_decoding_length=max_decoding_length, helper=helper) yy_ids = tx.utils.varlength_roll( outputs_infer.sample_id, -batch['%s_len' % context_name]) yy_len = len_infer - batch['%s_len' % context_name] yy_ids = yy_ids[:, :tf.reduce_max(yy_len)] # yy_logits = outputs_infer.logits # # yy_loss = _evaluate_loss_test(yy_logits, target_name, context_name) return yy_ids, yy_len def _evaluate_loss_test(target_name, context_name, bpe_loss=FLAGS.bpe_loss): ids = batch['%s_ids' % target_name] full_len = batch['%s_len' % target_name] ids = ids[:, :tf.reduce_max(full_len)] batch_size__ = tf.shape(ids)[0] seq_len = tf.fill([batch_size__], tf.shape(ids)[1]) pos_embeds = pos_embedder(sequence_length=seq_len) input_embeds = word_embedder(ids) + pos_embeds # greedy output outputs = decoder(inputs=input_embeds, decoding_strategy='train_greedy') max_full_len = tf.reduce_max(full_len) logits = outputs.logits[:, :max_full_len] test_loss = tx.losses.sequence_sparse_softmax_cross_entropy( labels=ids[:, 1:], logits=logits[:, :-1, :], sequence_length=full_len - 1, average_across_timesteps=False, sum_over_timesteps=False, # not bpe_loss, # True, average_across_batch=False, sum_over_batch=False) mask_recon = tf.sequence_mask( full_len - 1, dtype=tf.float32) mask_recon_prefix = 1 - tf.sequence_mask( batch['%s_len' % context_name] - 1, maxlen=max_full_len - 1, # max_decoding_length-1, dtype=tf.float32) mask_recon = mask_recon * mask_recon_prefix test_loss = tx.utils.reduce_with_weights( tensor=test_loss, weights=mask_recon, average_across_batch=bpe_loss, average_across_remaining=bpe_loss, sum_over_remaining=not bpe_loss) 
return test_loss # [bs,] ? x4_ids_fine, x4_len_fine = _infer('x1') x4_loss_fine = _evaluate_loss_test('x1x4', 'x1') ## Optimization def _get_beam_ids(context_name): # beam-search predictions = decoder( beam_width=5, length_penalty=config_train.length_penalty, embedding=_embedding_fn, context=batch['%s_ids' % context_name], context_sequence_length=batch['%s_len' % context_name], max_decoding_length=max_decoding_length, end_token=end_token, mode=tf.estimator.ModeKeys.PREDICT) beam_output_ids = tx.utils.varlength_roll(predictions["sample_id"][:, :, 0], -batch['%s_len' % context_name]) return beam_output_ids beam_search_ids = _get_beam_ids('x1') def _get_greedy_story(context_name): greedy_res, greedy_len = decoder( decoding_strategy='infer_greedy', embedding=_embedding_fn, context=batch['%s_ids' % context_name], context_sequence_length=batch['%s_len' % context_name], max_decoding_length=max_decoding_length, end_token=end_token, mode=tf.estimator.ModeKeys.PREDICT) greedy_ids = tx.utils.varlength_roll(greedy_res.sample_id, -batch['%s_len' % context_name]) greedy_ids_len = greedy_len - batch['%s_len' % context_name] greedy_ids = greedy_ids[:, :tf.reduce_max(greedy_ids_len)] return greedy_ids, greedy_ids_len greedy_ids, greedy_len = _get_greedy_story('x1') trainable_variables = tx.utils.collect_trainable_variables( [word_embedder, pos_embedder, decoder]) global_step = tf.Variable(0, trainable=False) opt = tx.core.get_optimizer( global_step=global_step, hparams=config_train.opt) if FLAGS.distributed: opt = hvd.DistributedOptimizer(opt) train_op = tf.contrib.layers.optimize_loss( loss=loss, global_step=global_step, learning_rate=None, optimizer=opt, variables=trainable_variables) ## Train/eval/test routine saver = tf.train.Saver() saver_best = tf.train.Saver(max_to_keep=1) dev_best = { 'loss': 1e8, 'loss_fine': 1e8} def _log_losses(losses, step=None): loss_str = 'loss: %.4f, loss_fine: %.4f' % \ (losses['loss'], losses['loss_fine']) if step is not None: loss_str = 'step: %d, %s' % (step, loss_str) _log(loss_str) def _is_head(): if not FLAGS.distributed: return True else: return hvd.rank() == 0 def _train_epoch(sess, initial=False): """Trains on the training set, and evaluates on the dev set periodically. 
""" iterator.restart_dataset(sess, 'train') while True: try: # (1) Get data and yy sample fetches_data = { 'batch': batch, 'batch_size': batch_size, } feed_dict_data = { iterator.handle: iterator.get_handle(sess, 'train'), tx.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets_data = sess.run(fetches_data, feed_dict_data) # (2) Optimize loss feed_dict = { #x1_ids: rets_data['batch']['x1_ids'], x1_len: rets_data['batch']['x1_len'], x1x4_ids: rets_data['batch']['x1x4_ids'], x1x4_len: rets_data['batch']['x1x4_len'], tau: config_train.tau, tx.global_mode(): tf.estimator.ModeKeys.TRAIN, } fetches = { 'train_op': train_op, 'step': global_step, } fetches.update(loss_dict) rets = sess.run(fetches, feed_dict) step = rets['step'] dis_steps = config_train.display_steps if _is_head() and dis_steps > 0 and step % dis_steps == 0: _log_losses(rets, step) eval_steps = config_train.eval_steps if _is_head() and eval_steps > 0 and step % eval_steps == 0: _dev_epoch(sess) sample_steps = config_train.sample_steps if _is_head() and sample_steps > 0 and step % sample_steps == 0: print('-----------testing-----------------') _test_epoch(sess, step=step) ckpt_steps = config_train.checkpoint_steps if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0: ckpt_fn = os.path.join(output_dir, 'model.ckpt') ckpt_fn = saver.save(sess, ckpt_fn, global_step=step) _log('Checkpoint to {}'.format(ckpt_fn)) except tf.errors.OutOfRangeError: break def _dev_epoch(sess): """Evaluates on the dev set. """ iterator.restart_dataset(sess, 'dev') results = tx.utils.AverageRecorder() nsamples = 0 fetches = {} fetches.update(loss_dict) # i = 0 while True: try: # (1) Get data and yy sample fetches_data = { 'batch': batch, 'batch_size': batch_size, } feed_dict_data = { iterator.handle: iterator.get_handle(sess, 'dev'), tx.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets_data = sess.run(fetches_data, feed_dict_data) # (2) eval loss feed_dict = { #x1_ids: rets_data['batch']['x1_ids'], x1_len: rets_data['batch']['x1_len'], x1x4_ids: rets_data['batch']['x1x4_ids'], x1x4_len: rets_data['batch']['x1x4_len'], tau: config_train.tau, tx.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets = sess.run(fetches, feed_dict) results.add(rets, weight=rets_data['batch_size']) nsamples += rets_data['batch_size'] except tf.errors.OutOfRangeError: break _log_losses(results.avg()) _log('nsamples: %d' % nsamples) avg_loss = results.avg('loss') if FLAGS.do_train and avg_loss < dev_best['loss']: dev_best.update(results.avg()) ckpt_fn = os.path.join(output_dir, 'model_best.ckpt') ckpt_fn = saver_best.save(sess, ckpt_fn) _log('Checkpoint best to {}'.format(ckpt_fn)) def _test_epoch(sess, step=None): """Generates samples on the test set. 
""" iterator.restart_dataset(sess, 'test') _all_inputs = [] _all_samples = [] _all_loss = [] # if FLAGS.finetune and FLAGS.roc: # raise ValueError('Cannot set --finetune and --roc at the same time') if FLAGS.finetune: _log('Generation input: x1') if FLAGS.greedy: fetches = { 'inputs': batch['x1_ids'], 'length': batch['x1_len'], 'samples_length': greedy_len, 'samples': greedy_ids } elif FLAGS.beam: fetches = { 'inputs': batch['x1_ids'], 'length': batch['x1_len'], # 'samples_length': x4_len_fine, 'samples': beam_search_ids } else: fetches = { 'inputs': batch['x1_ids'], 'length': batch['x1_len'], 'samples_length': x4_len_fine, 'samples': x4_ids_fine, 'sample_loss': x4_loss_fine, 'outputs': batch['x1x4_ids'], 'out_length': batch['x1x4_len'] } res_fn_appendix = "x1" while True: try: feed_dict = { iterator.handle: iterator.get_handle(sess, 'test'), tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT, } rets = sess.run(fetches, feed_dict=feed_dict) # ! ---- _inputs = [] for i, l in zip(rets['inputs'], rets['length']): # Delete padding _inputs.append(i[:l].tolist()) _all_inputs.extend(_inputs) _samples = [] if not FLAGS.beam: for s, l in zip(rets['samples'], rets['samples_length']): _samples.append(s[:l].tolist()) else: _samples.extend(h.tolist() for h in rets['samples']) _samples = utils.list_strip_eos(_samples, eos_token=proc.encoder['<|endoftext|>']) _all_samples.extend(_samples) # ----! _loss = [] if not FLAGS.bpe_loss: for los in rets["sample_loss"]: _loss.append(los) else: _loss = [rets["sample_loss"]] _all_loss.extend(_loss) except tf.errors.OutOfRangeError: break # Parse samples and write to file eos_token_id = proc.encoder['<|endoftext|>'] # !---- _all_input_text = [] for i in _all_inputs: if i[0] == eos_token_id: i = i[1:] i_text = proc.decode(i) _all_input_text.append(i_text) _all_input_text = tx.utils.strip_eos(_all_input_text, eos_token='<|endoftext|>') _all_samples_text = [] for j, (i, s) in enumerate(zip(_all_inputs, _all_samples)): s_text = proc.decode(s) s_text = s_text.replace('\n', ' ') # print(s_text) _all_samples_text.append(s_text) if j % 1000 == 0: print("{} stories is process of total {}".format(j, len(_all_inputs))) _all_samples_text = tx.utils.strip_eos(_all_samples_text, eos_token='<|endoftext|>') if step is None: fn = "test_samples_%s_sample_k%d.tsv" % (res_fn_appendix, FLAGS.top_k) else: fn = "test_samples_%s_%d_beam.tsv" % (res_fn_appendix, step) output_file = os.path.join(output_dir, fn) _log('Write samples to {}'.format(output_file)) if not FLAGS.beam: tx.utils.write_paired_text( _all_input_text, _all_samples_text, output_file) with open(output_file[:-4]+".txt", 'w') as f: for item in _all_samples_text: f.write("%s\n" % item.strip(" | ")) else: with open(output_file, 'w') as f: for item in _all_samples_text: f.write("%s\n" % item) # ----! 
if FLAGS.ppl: if not FLAGS.bpe_loss: # load target file target = [i.strip().split() for i in open("emotion_evaluation/baselines/ground-truth/ground_truth_story-processed.txt")] for j, (txt, los) in enumerate(zip(target, _all_loss)): _all_loss[j] = los/len(txt) np.save(os.path.join(output_dir, "test_loss_word.npy"), np.array(_all_loss)) avg_loss = np.mean(np.array(_all_loss)) ppl = np.exp(avg_loss) msg = 'test_loss (per word): %.4f, test_perplexity: %.4f' % \ (avg_loss, ppl ) else: avg_loss = np.mean(np.array(_all_loss)) ppl = np.exp(avg_loss) msg = 'test_loss (bpe): %.4f, test_perplexity: %.4f' % \ (avg_loss, ppl ) _log(msg) # Broadcasts global variables from rank-0 process if FLAGS.distributed: bcast = hvd.broadcast_global_variables(0) session_config = tf.ConfigProto() if FLAGS.distributed: session_config.gpu_options.visible_device_list = str(hvd.local_rank()) with tf.Session(config=session_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) # smry_writer = tf.summary.FileWriter(FLAGS.output_dir, graph=sess.graph) if FLAGS.distributed: bcast.run() #Restores trained model if specified if FLAGS.checkpoint: _log('Restore from {}'.format(FLAGS.checkpoint)) saver.restore(sess, FLAGS.checkpoint) elif FLAGS.pretrain_checkpoint: _log('Restore from {}'.format(FLAGS.pretrain_checkpoint)) model_utils.init_gpt2_checkpoint(sess, FLAGS.pretrain_checkpoint) print("\nFinished loading\n") saver.save(sess, output_dir + '/gpt2_model.ckpt') iterator.initialize_dataset(sess) if FLAGS.do_train: for epoch in range(config_train.max_train_epoch): print("Training epoch {}".format(epoch)) _train_epoch(sess, epoch==0) saver.save(sess, output_dir + '/model.ckpt') if FLAGS.do_eval: _dev_epoch(sess) if FLAGS.do_test: _test_epoch(sess)
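# ---------------------------------------------------------------------------
# Illustrative NumPy sketch of the perplexity reported in `_test_epoch` when
# --ppl is set: with --bpe_loss the per-BPE losses are averaged directly;
# otherwise each story's summed loss is first divided by its word count (taken
# from the ground-truth file) before averaging, giving per-word perplexity.
# ---------------------------------------------------------------------------
import numpy as np


def perplexity(per_story_losses, words_per_story=None):
    losses = np.asarray(per_story_losses, dtype=np.float64)
    if words_per_story is not None:                       # per-word normalization
        losses = losses / np.asarray(words_per_story, dtype=np.float64)
    return float(np.exp(losses.mean()))

# Example: perplexity([40.2, 55.1], words_per_story=[38, 52])
#          -> exp of the mean per-word cross-entropy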