# Module-level imports assumed by this decoder (module paths follow the
# original PyTorch pointer-generator project layout):
import os
import time

import torch
from torch.autograd import Variable

from data_util import config, data
from data_util.batcher import Batcher
from data_util.data import Vocab
# Beam, Model, get_input_from_batch, write_for_rouge, rouge_eval, rouge_log
# and use_cuda are assumed to be defined elsewhere in the project.


class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)
        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get the best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass

            original_abstract = batch.original_abstracts_sents[0]
            write_for_rouge(original_abstract, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()
            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        # The batch should contain only one example, repeated beam_size times.
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # Decoder batch preparation: it has beam_size examples, and initially
        # everything is repeated.
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context=c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]

        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            # Map any in-article OOV id back to UNK before the embedding lookup.
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN)
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()

            all_state_h = []
            all_state_c = []
            all_context = []
            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = [h.coverage for h in beams]
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)

            topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)
        return beams_sorted[0]
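

# For orientation, a minimal driver for the class above (a sketch: it assumes a
# trained checkpoint exists on disk, and the checkpoint path below is
# hypothetical). decode() consumes config.decode_data_path in a single pass,
# writes the reference and decoded files, and then reports ROUGE.
if __name__ == '__main__':
    beam_processor = BeamSearch('log/exp_1/model/model_50000')  # hypothetical path
    beam_processor.decode()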
# Module-level imports assumed by this entry point (TF 1.x). FLAGS,
# SummarizationModel, Batcher, Vocab, BeamSearchDecoder, util, setup_training
# and run_eval are assumed to be imported/defined elsewhere in the file.
import os
from collections import namedtuple

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow


def main(unused_argv):
    print("unused_argv: ", unused_argv)
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

    print("FLAGS.vocab_size: ", FLAGS.vocab_size)
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary
    print("vocab size: ", vocab.size())

    # If in decode mode, set batch_size = beam_size.
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam,
    # so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage',
        'cov_loss_wt', 'pointer_gen', 'fine_tune', 'train_size',
        'subred_size', 'use_doc_vec', 'use_multi_attn', 'use_multi_pgen',
        'use_multi_pvocab', 'create_ckpt'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data
    batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)

    tf.set_random_seed(111)  # a seed value for randomness

    # return
    if hps.mode.value == 'train':
        print("creating model...")
        model = SummarizationModel(hps, vocab)
        # -------------------------------------
        if hps.create_ckpt.value:
            step = 0
            model.build_graph()
            print("get value")
            pretrained_ckpt = '/home/cs224u/pointer/log/pretrained_model_tf1.2.1/train/model-238410'
            reader = pywrap_tensorflow.NewCheckpointReader(pretrained_ckpt)
            var_to_shape_map = reader.get_variable_to_shape_map()
            value = {}
            for key in var_to_shape_map:
                value[key] = reader.get_tensor(key)

            print("assign op")
            assign_op = []
            if hps.use_multi_pvocab.value:
                new_key = [
                    "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_0/Bias",
                    "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_1/Bias"
                ]
                for v in tf.trainable_variables():
                    key = v.name.split(":")[0]
                    if key in new_key:
                        origin_key = "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/" + key.split("/")[-1]
                        a_op = v.assign(tf.convert_to_tensor(value[origin_key]))
                    else:
                        a_op = v.assign(tf.convert_to_tensor(value[key]))
                    # if key == "seq2seq/embedding/embedding":
                    #     a_op = v.assign(tf.convert_to_tensor(value[key]))
                    assign_op.append(a_op)
            else:
                for v in tf.trainable_variables():
                    key = v.name.split(":")[0]
                    if key == "seq2seq/embedding/embedding":
                        a_op = v.assign(tf.convert_to_tensor(value[key]))
                        assign_op.append(a_op)
            # ratio = 1
            # for v in tf.trainable_variables():
            #     key = v.name.split(":")[0]
            #     # embedding (50000, 128) -> (50000, 32)
            #     if key == "seq2seq/embedding/embedding":
            #         print (key)
            #         print (value[key].shape)
            #         d1 = value[key].shape[1]
            #         a_op = v.assign(tf.convert_to_tensor(value[key][:,:d1//ratio]))
            #     # kernel (384, 1024) -> (96, 256)
            #     # w_reduce_c (512, 256) -> (128, 64)
            #     elif key == "seq2seq/encoder/bidirectional_rnn/fw/lstm_cell/kernel" or \
            #          key == "seq2seq/encoder/bidirectional_rnn/bw/lstm_cell/kernel" or \
            #          key == "seq2seq/reduce_final_st/w_reduce_c" or \
            #          key == "seq2seq/reduce_final_st/w_reduce_h" or \
            #          key == "seq2seq/decoder/attention_decoder/Linear/Matrix" or \
            #          key == "seq2seq/decoder/attention_decoder/lstm_cell/kernel" or \
            #          key == "seq2seq/decoder/attention_decoder/Attention/Linear/Matrix" or \
            #          key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/Matrix":
            #         print (key)
            #         print (value[key].shape)
            #         d0, d1 = value[key].shape[0], value[key].shape[1]
            #         a_op = v.assign(tf.convert_to_tensor(value[key][:d0//ratio, :d1//ratio]))
            #     # bias (1024,) -> (256,)
            #     elif key == "seq2seq/encoder/bidirectional_rnn/fw/lstm_cell/bias" or \
            #          key == "seq2seq/encoder/bidirectional_rnn/bw/lstm_cell/bias" or \
            #          key == "seq2seq/reduce_final_st/bias_reduce_c" or \
            #          key == "seq2seq/reduce_final_st/bias_reduce_h" or \
            #          key == "seq2seq/decoder/attention_decoder/lstm_cell/bias" or \
            #          key == "seq2seq/decoder/attention_decoder/v" or \
            #          key == "seq2seq/decoder/attention_decoder/Attention/Linear/Bias" or \
            #          key == "seq2seq/decoder/attention_decoder/Linear/Bias" or \
            #          key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/Bias":
            #         print (key)
            #         print (value[key].shape)
            #         d0 = value[key].shape[0]
            #         a_op = v.assign(tf.convert_to_tensor(value[key][:d0//ratio]))
            #     # W_h (1, 1, 512, 512) -> (1, 1, 128, 128)
            #     elif key == "seq2seq/decoder/attention_decoder/W_h":
            #         print (key)
            #         print (value[key].shape)
            #         d2, d3 = value[key].shape[2], value[key].shape[3]
            #         a_op = v.assign(tf.convert_to_tensor(value[key][:,:,:d2//ratio,:d3//ratio]))
            #     # Matrix (1152, 1) -> (288, 1)
            #     elif key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear/Matrix" or \
            #          key == "seq2seq/output_projection/w":
            #         print (key)
            #         print (value[key].shape)
            #         d0 = value[key].shape[0]
            #         a_op = v.assign(tf.convert_to_tensor(value[key][:d0//ratio,:]))
            #     # Bias (1,) -> (1,)
            #     elif key == "seq2seq/output_projection/v" or \
            #          key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear/Bias":
            #         print (key)
            #         print (value[key].shape)
            #         a_op = v.assign(tf.convert_to_tensor(value[key]))
            #     # multi_attn
            #     if hps.use_multi_attn.value:
            #         if key == "seq2seq/decoder/attention_decoder/attn_0/v" or \
            #            key == "seq2seq/decoder/attention_decoder/attn_1/v":
            #            # key == "seq2seq/decoder/attention_decoder/attn_2/v":
            #             k = "seq2seq/decoder/attention_decoder/v"
            #             print (key)
            #             print (value[k].shape)
            #             d0 = value[k].shape[0]
            #             a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio]))
            #         if key == "seq2seq/decoder/attention_decoder/Attention/Linear_0/Bias" or \
            #            key == "seq2seq/decoder/attention_decoder/Attention/Linear_1/Bias":
            #            # key == "seq2seq/decoder/attention_decoder/Attention/Linear_2/Bias":
            #             k = "seq2seq/decoder/attention_decoder/Attention/Linear/Bias"
            #             print (key)
            #             print (value[k].shape)
            #             d0 = value[k].shape[0]
            #             a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio]))
            #     elif hps.use_multi_pgen.value:
            #         if key == "seq2seq/decoder/attention_decoder/Linear_0/Bias" or \
            #            key == "seq2seq/decoder/attention_decoder/Linear_1/Bias":
            #            # key == "seq2seq/decoder/attention_decoder/Linear_2/Bias":
# k = "seq2seq/decoder/attention_decoder/Linear/Bias" # print (key) # print (value[k].shape) # d0 = value[k].shape[0] # a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio])) # if key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear_0/Bias" or \ # key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear_1/Bias": # # key == "seq2seq/decoder/attention_decoder/calculate_pgen/Linear_2/Bias": # k = "seq2seq/decoder/attention_decoder/calculate_pgen/Linear/Bias" # print (key) # print (value[k].shape) # a_op = v.assign(tf.convert_to_tensor(value[k])) # elif hps.use_multi_pvocab.value: # if key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_0/Bias" or \ # key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_1/Bias": # # key == "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear_2/Bias": # k = "seq2seq/decoder/attention_decoder/AttnOutputProjection/Linear/Bias" # print (key) # print (value[k].shape) # d0 = value[k].shape[0] # a_op = v.assign(tf.convert_to_tensor(value[k][:d0//ratio])) # assign_op.append(a_op) # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables. saver = tf.train.Saver() with tf.Session(config=util.get_config()) as sess: sess.run(init_op) # Do some work with the model. for a_op in assign_op: a_op.op.run() for _ in range(0): batch = batcher.next_batch() results = model.run_train_step(sess, batch) # Save the variables to disk. if hps.use_multi_attn.value: ckpt_tag = "multi_attn_2_attn_proj" elif hps.use_multi_pgen.value: ckpt_tag = "multi_attn_2_pgen_proj" elif hps.use_multi_pvocab.value: ckpt_tag = "big_multi_attn_2_pvocab_proj" else: ckpt_tag = "pointer_proj" ckpt_to_save = '/home/cs224u/pointer/log/ckpt/' + ckpt_tag + '/model.ckpt-' + str( step) save_path = saver.save(sess, ckpt_to_save) print("Model saved in path: %s" % save_path) # ------------------------------------- else: setup_training(model, batcher, hps) elif hps.mode.value == 'eval': model = SummarizationModel(hps, vocab) run_eval(model, batcher, vocab) elif hps.mode.value == 'decode': decode_model_hps = hps # This will be the hyperparameters for the decoder model decode_model_hps = hps._replace( max_dec_steps=1 ) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries model = SummarizationModel(decode_model_hps, vocab) decoder = BeamSearchDecoder(model, batcher, vocab) decoder.decode( ) # decode indefinitely (unless single_pass=True, in which case deocde the dataset exactly once) else: raise ValueError("The 'mode' flag must be one of train/eval/decode")
# Module-level imports assumed by this class (TF 1.x; module paths follow the
# original project layout). util, FLAGS, SummarizationModel, DQN, ReplayBuffer,
# BeamSearchDecoder, Vocab, Batcher and the tf.python op modules used by
# scheduled_sampling are assumed to be imported/defined elsewhere in the file.
import os
import time
from collections import namedtuple
from glob import glob
from threading import Thread

import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug


class Seq2Seq(object):
    def calc_running_avg_loss(self, loss, running_avg_loss, step, decay=0.99):
        """Calculate the running average loss via exponential decay.
        This is used to implement early stopping w.r.t. a smoother loss curve than the raw one.

        Args:
            loss: loss on the most recent eval step
            running_avg_loss: running_avg_loss so far
            step: training iteration step
            decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.

        Returns:
            running_avg_loss: new running average loss
        """
        if running_avg_loss == 0:  # on the first iteration just take the loss
            running_avg_loss = loss
        else:
            running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
        running_avg_loss = min(running_avg_loss, 12)  # clip
        loss_sum = tf.Summary()
        tag_name = 'running_avg_loss/decay=%f' % (decay)
        loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
        self.summary_writer.add_summary(loss_sum, step)
        tf.logging.info('running_avg_loss: %f', running_avg_loss)
        return running_avg_loss

    def restore_best_model(self):
        """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory."""
        tf.logging.info("Restoring bestmodel for training...")

        # Initialize all vars in the model
        sess = tf.Session(config=util.get_config())
        print("Initializing all variables...")
        sess.run(tf.global_variables_initializer())

        # Restore the best model from eval dir
        saver = tf.train.Saver([v for v in tf.global_variables() if "Adagrad" not in v.name])
        print("Restoring all non-adagrad variables from best model in eval dir...")
        curr_ckpt = util.load_ckpt(saver, sess, "eval")
        print("Restored %s." % curr_ckpt)

        # Save this model to train dir and quit
        new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
        new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
        print("Saving model to %s..." % (new_fname))
        new_saver = tf.train.Saver()  # this saver saves all variables that now exist, including Adagrad variables
        new_saver.save(sess, new_fname)
        print("Saved.")
        exit()

    def restore_best_eval_model(self):
        # load the best evaluation loss so far
        best_loss = None
        best_step = None
        # go through all event files, select the best loss achieved, and return it
        event_files = sorted(glob('{}/eval/events*'.format(FLAGS.log_root)))
        for ef in event_files:
            try:
                for e in tf.train.summary_iterator(ef):
                    for v in e.summary.value:
                        step = e.step
                        if 'running_avg_loss/decay' in v.tag:
                            running_avg_loss = v.simple_value
                            if best_loss is None or running_avg_loss < best_loss:
                                best_loss = running_avg_loss
                                best_step = step
            except:
                continue
        tf.logging.info('restoring best loss from the current logs: {}\tstep: {}'.format(best_loss, best_step))
        return best_loss

    def convert_to_coverage_model(self):
        """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint."""
        tf.logging.info("converting non-coverage model to coverage model..")

        # initialize an entire coverage model from scratch
        sess = tf.Session(config=util.get_config())
        print("initializing everything...")
        sess.run(tf.global_variables_initializer())

        # load all non-coverage weights from checkpoint
        saver = tf.train.Saver([v for v in tf.global_variables()
                                if "coverage" not in v.name and "Adagrad" not in v.name])
        print("restoring non-coverage variables...")
        curr_ckpt = util.load_ckpt(saver, sess)
        print("restored.")

        # save this model and quit
        new_fname = curr_ckpt + '_cov_init'
        print("saving model to %s..." % (new_fname))
        new_saver = tf.train.Saver()  # this one will save all variables that now exist
        new_saver.save(sess, new_fname)
        print("saved.")
        exit()

    def convert_to_reinforce_model(self):
        """Load non-reinforce checkpoint, add initialized extra variables for reinforce, and save as new checkpoint."""
        tf.logging.info("converting non-reinforce model to reinforce model..")

        # initialize an entire reinforce model from scratch
        sess = tf.Session(config=util.get_config())
        print("initializing everything...")
        sess.run(tf.global_variables_initializer())

        # load all non-reinforce weights from checkpoint
        saver = tf.train.Saver([v for v in tf.global_variables()
                                if "reinforce" not in v.name and "Adagrad" not in v.name])
        print("restoring non-reinforce variables...")
        curr_ckpt = util.load_ckpt(saver, sess)
        print("restored.")

        # save this model and quit
        new_fname = curr_ckpt + '_rl_init'
        print("saving model to %s..." % (new_fname))
        new_saver = tf.train.Saver()  # this one will save all variables that now exist
        new_saver.save(sess, new_fname)
        print("saved.")
        exit()

    def setup_training(self):
        """Does setup before starting training (run_training)."""
        train_dir = os.path.join(FLAGS.log_root, "train")
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)
        if FLAGS.ac_training:
            dqn_train_dir = os.path.join(FLAGS.log_root, "dqn", "train")
            if not os.path.exists(dqn_train_dir):
                os.makedirs(dqn_train_dir)
        #replaybuffer_pcl_path = os.path.join(FLAGS.log_root, "replaybuffer.pcl")
        #if not os.path.exists(dqn_target_train_dir): os.makedirs(dqn_target_train_dir)

        self.model.build_graph()  # build the graph

        if FLAGS.convert_to_reinforce_model:
            assert (FLAGS.rl_training or FLAGS.ac_training), \
                "To convert your pointer model to a reinforce model, run with convert_to_reinforce_model=True and either rl_training=True or ac_training=True"
            self.convert_to_reinforce_model()
        if FLAGS.convert_to_coverage_model:
            assert FLAGS.coverage, \
                "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
            self.convert_to_coverage_model()
        if FLAGS.restore_best_model:
            self.restore_best_model()

        saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time

        # Loads a pre-trained word embedding. By default the model learns the embedding.
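        # (Presumably word_vector is a (vocab_size, emb_dim) numpy array; it is
        # fed into the graph once via init_feed_dict when the Supervisor below
        # creates the session.)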
        if FLAGS.embedding:
            self.vocab.LoadWordEmbedding(FLAGS.embedding, FLAGS.emb_dim)
            word_vector = self.vocab.getWordEmbedding()

        self.sv = tf.train.Supervisor(logdir=train_dir,
                                      is_chief=True,
                                      saver=saver,
                                      summary_op=None,
                                      save_summaries_secs=60,  # save summaries for tensorboard every 60 secs
                                      save_model_secs=60,  # checkpoint every 60 secs
                                      global_step=self.model.global_step,
                                      init_feed_dict={self.model.embedding_place: word_vector} if FLAGS.embedding else None
                                      )
        self.summary_writer = self.sv.summary_writer
        self.sess = self.sv.prepare_or_wait_for_session(config=util.get_config())

        if FLAGS.ac_training:
            tf.logging.info('DDQN building graph')
            t1 = time.time()
            # We create a separate graph for DDQN
            self.dqn_graph = tf.Graph()
            with self.dqn_graph.as_default():
                self.dqn.build_graph()  # build dqn graph
                tf.logging.info('building current network took {} seconds'.format(time.time() - t1))
                self.dqn_target.build_graph()  # build dqn target graph
                tf.logging.info('building target network took {} seconds'.format(time.time() - t1))
                dqn_saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time
                self.dqn_sv = tf.train.Supervisor(logdir=dqn_train_dir,
                                                  is_chief=True,
                                                  saver=dqn_saver,
                                                  summary_op=None,
                                                  save_summaries_secs=60,  # save summaries for tensorboard every 60 secs
                                                  save_model_secs=60,  # checkpoint every 60 secs
                                                  global_step=self.dqn.global_step,
                                                  )
                self.dqn_summary_writer = self.dqn_sv.summary_writer
                self.dqn_sess = self.dqn_sv.prepare_or_wait_for_session(config=util.get_config())
            '''
            #### TODO: try loading a previously saved replay buffer
            # right now this doesn't work due to running DQN on a thread
            if os.path.exists(replaybuffer_pcl_path):
                tf.logging.info('Loading Replay Buffer...')
                try:
                    self.replay_buffer = pickle.load(open(replaybuffer_pcl_path, "rb"))
                    tf.logging.info('Replay Buffer loaded...')
                except:
                    tf.logging.info('Couldn\'t load Replay Buffer file...')
                    self.replay_buffer = ReplayBuffer(self.dqn_hps)
            else:
                self.replay_buffer = ReplayBuffer(self.dqn_hps)
            tf.logging.info("Building DDQN took {} seconds".format(time.time()-t1))
            '''
            self.replay_buffer = ReplayBuffer(self.dqn_hps)

        tf.logging.info("Preparing or waiting for session...")
        tf.logging.info("Created session.")
        try:
            self.run_training()  # this is an infinite loop until interrupted
        except (KeyboardInterrupt, SystemExit):
            tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
            self.sv.stop()
            if FLAGS.ac_training:
                self.dqn_sv.stop()

    def run_training(self):
        """Repeatedly runs training iterations, logging loss to screen and writing summaries."""
        tf.logging.info("Starting run_training")

        if FLAGS.debug:  # start the tensorflow debugger
            self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
            self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

        self.train_step = 0
        if FLAGS.ac_training:
            # DDQN training is done asynchronously along with model training
            tf.logging.info('Starting DQN training thread...')
            self.dqn_train_step = 0
            self.thrd_dqn_training = Thread(target=self.dqn_training)
            self.thrd_dqn_training.daemon = True
            self.thrd_dqn_training.start()

            watcher = Thread(target=self.watch_threads)
            watcher.daemon = True
            watcher.start()

        # starting the main thread
        tf.logging.info('Starting Seq2Seq training...')
        while True:  # repeats until interrupted
            batch = self.batcher.next_batch()
            t0 = time.time()
            if FLAGS.ac_training:
                # For DDQN, we first collect the model output to calculate the reward and Q-estimates.
                # Then we fix the estimation either using our target network or using the true Q-values.
                # This process usually takes time, and we are working on improving it.
                transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs)  # len(batch_size * k * max_dec_steps)
                tf.logging.info('Q-values collection time: {}'.format(time.time() - t0))
                # whenever we are working with the DDQN, we switch to the DDQN graph rather than the default graph
                with self.dqn_graph.as_default():
                    batch_len = len(transitions)
                    # we use the current decoder state to predict q_estimates: use_state_prime=False
                    b = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                                  use_state_prime=False, max_art_oovs=batch.max_art_oovs)
                    # we also get the next decoder state to correct the estimation: use_state_prime=True
                    b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                                        use_state_prime=True, max_art_oovs=batch.max_art_oovs)
                    # use the current DQN to estimate values from the current decoder state
                    dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x=b._x, return_best_action=True)
                    q_estimates = dqn_results['estimates']  # shape (len(transitions), vocab_size)
                    dqn_best_action = dqn_results['best_action']
                    #dqn_q_estimate_loss = dqn_results['loss']
                    # use the target DQN to estimate values for the next decoder state
                    dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x=b_prime._x)
                    q_vals_new_t = dqn_target_results['estimates']  # shape (len(transitions), vocab_size)
                    # we need to expand the q_estimates to match the input batch max_art_oov;
                    # we use the q_estimate of the UNK token for all the OOV tokens
                    q_estimates = np.concatenate(
                        [q_estimates,
                         np.reshape(q_estimates[:, 0], [-1, 1]) * np.ones((len(transitions), batch.max_art_oovs))],
                        axis=-1)
                    # modify Q-estimates using the result collected from the current and target DQN.
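                    # Concretely (the double-DQN correction applied in the loop below):
                    # for a non-terminal transition with reward r, the value of the taken
                    # action becomes r + gamma * Q_target(s', argmax_a Q_current(s', a));
                    # for a terminal transition it is just r.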
                    # check algorithm 5 in the paper for more info: https://arxiv.org/pdf/1805.09461.pdf
                    for i, tr in enumerate(transitions):
                        if tr.done:
                            q_estimates[i][tr.action] = tr.reward
                        else:
                            q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
                    # use scheduled sampling to decide whether to use the true Q-values or the DDQN estimate
                    if FLAGS.dqn_scheduled_sampling:
                        q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability,
                                                              b._y_extended, q_estimates)
                    if not FLAGS.calculate_true_q:
                        # when we are not training the DDQN on true Q-values,
                        # we need to update the Q-values in our transitions based on the
                        # q_estimates we collected from the current DQN network.
                        for trans, q_val in zip(transitions, q_estimates):
                            trans.q_values = q_val  # each has size vocab_extended
                    q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1])  # shape (batch_size, k, max_dec_steps, vocab_size_extended)
                # Once we are done with modifying the Q-values, we can use them to train the DDQN model.
                # In this paper, we use a priority experience buffer which always selects states with higher quality
                # to train the DDQN. The following line will add batch_size * max_dec_steps experiences to the replay buffer.
                # As mentioned before, the DDQN training is asynchronous. Therefore, once the related queues for DDQN training
                # are full, the DDQN will start the training.
                self.replay_buffer.add(transitions)
                # If the dqn_pretrain flag is on, we use a fixed Actor to only collect experiences for
                # DDQN pre-training.
                if FLAGS.dqn_pretrain:
                    tf.logging.info('RUNNING DQN PRETRAIN: Adding data to replay buffer only...')
                    continue
                # if not, use the q_estimation to update the loss.
                results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates)
            else:
                results = self.model.run_train_steps(self.sess, batch, self.train_step)
            t1 = time.time()

            # get the summaries and iteration number so we can write summaries to tensorboard
            summaries = results['summaries']  # we will write these summaries to tensorboard using summary_writer
            self.train_step = results['global_step']  # we need this to update our running average loss
            tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1 - t0))

            printer_helper = {}
            printer_helper['pgen_loss'] = results['pgen_loss']
            if FLAGS.coverage:
                printer_helper['coverage_loss'] = results['coverage_loss']
                if FLAGS.rl_training or FLAGS.ac_training:
                    printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
                else:
                    printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
            if FLAGS.rl_training or FLAGS.ac_training:
                printer_helper['shared_loss'] = results['shared_loss']
                printer_helper['rl_loss'] = results['rl_loss']
                printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
            if FLAGS.rl_training:
                printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
                printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
                printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
            if FLAGS.ac_training:
                printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

            for (k, v) in printer_helper.items():
                if not np.isfinite(v):
                    raise Exception("{} is not finite. Stopping.".format(k))
Stopping.".format(k)) tf.logging.info('{}: {}\t'.format(k,v)) tf.logging.info('-------------------------------------------') self.summary_writer.add_summary(summaries, self.train_step) # write the summaries if self.train_step % 100 == 0: # flush the summary writer every so often self.summary_writer.flush() if FLAGS.ac_training: self.dqn_summary_writer.flush() if self.train_step > FLAGS.max_iter: break def dqn_training(self): """ training the DDQN network.""" try: while True: if self.dqn_train_step == FLAGS.dqn_pretrain_steps: raise SystemExit() _t = time.time() self.avg_dqn_loss = [] avg_dqn_target_loss = [] # Get a batch of size dqn_batch_size from replay buffer to train the model dqn_batch = self.replay_buffer.next_batch() if dqn_batch is None: tf.logging.info('replay buffer not loaded enough yet...') time.sleep(60) continue # Run train step for Current DQN model and collect the results dqn_results = self.dqn.run_train_steps(self.dqn_sess, dqn_batch) # Run test step for Target DQN model and collect the results and monitor the difference in loss between the two dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x=dqn_batch._x, y=dqn_batch._y, return_loss=True) self.dqn_train_step = dqn_results['global_step'] self.dqn_summary_writer.add_summary(dqn_results['summaries'], self.dqn_train_step) # write the summaries self.avg_dqn_loss.append(dqn_results['loss']) avg_dqn_target_loss.append(dqn_target_results['loss']) self.dqn_train_step = self.dqn_train_step + 1 tf.logging.info('seconds for training dqn model: {}'.format(time.time()-_t)) # UPDATING TARGET DDQN NETWORK WITH CURRENT MODEL with self.dqn_graph.as_default(): current_model_weights = self.dqn_sess.run([self.dqn.model_trainables])[0] # get weights of current model self.dqn_target.run_update_weights(self.dqn_sess, self.dqn_train_step, current_model_weights) # update target model weights with current model weights tf.logging.info('DQN loss at step {}: {}'.format(self.dqn_train_step, np.mean(self.avg_dqn_loss))) tf.logging.info('DQN Target loss at step {}: {}'.format(self.dqn_train_step, np.mean(avg_dqn_target_loss))) # sleeping is required if you want the keyboard interuption to work time.sleep(FLAGS.dqn_sleep_time) except (KeyboardInterrupt, SystemExit): tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...") self.sv.stop() self.dqn_sv.stop() def watch_threads(self): """Watch example queue and batch queue threads and restart if dead.""" while True: time.sleep(60) if not self.thrd_dqn_training.is_alive(): # if the thread is dead tf.logging.error('Found DQN Learning thread dead. Restarting.') self.thrd_dqn_training = Thread(target=self.dqn_training) self.thrd_dqn_training.daemon = True self.thrd_dqn_training.start() def run_eval(self): """Repeatedly runs eval iterations, logging to screen and writing summaries. 
        self.model.build_graph()  # build the graph
        saver = tf.train.Saver(max_to_keep=3)  # we will keep 3 best checkpoints at a time
        sess = tf.Session(config=util.get_config())

        if FLAGS.embedding:
            sess.run(tf.global_variables_initializer(), feed_dict={self.model.embedding_place: self.word_vector})
        eval_dir = os.path.join(FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
        bestmodel_save_path = os.path.join(eval_dir, 'bestmodel')  # this is where checkpoints of best models are saved
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        if FLAGS.ac_training:
            tf.logging.info('DDQN building graph')
            t1 = time.time()
            dqn_graph = tf.Graph()
            with dqn_graph.as_default():
                self.dqn.build_graph()  # build dqn graph
                tf.logging.info('building current network took {} seconds'.format(time.time() - t1))
                self.dqn_target.build_graph()  # build dqn target graph
                tf.logging.info('building target network took {} seconds'.format(time.time() - t1))
                dqn_saver = tf.train.Saver(max_to_keep=3)  # keep 3 checkpoints at a time
                dqn_sess = tf.Session(config=util.get_config())
            dqn_train_step = 0
            replay_buffer = ReplayBuffer(self.dqn_hps)

        running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
        best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
        train_step = 0

        while True:
            _ = util.load_ckpt(saver, sess)  # load a new checkpoint
            if FLAGS.ac_training:
                _ = util.load_dqn_ckpt(dqn_saver, dqn_sess)  # load a new checkpoint
            processed_batch = 0
            avg_losses = []
            # evaluate for 100 * batch_size batches before comparing the loss;
            # we do this due to memory constraints; it is best to run eval on
            # different machines with a large batch size
            while processed_batch < 100 * FLAGS.batch_size:
                processed_batch += FLAGS.batch_size
                batch = self.batcher.next_batch()  # get the next batch
                if FLAGS.ac_training:
                    t0 = time.time()
                    transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs)  # len(batch_size * k * max_dec_steps)
                    tf.logging.info('Q values collection time: {}'.format(time.time() - t0))
                    with dqn_graph.as_default():
                        # if using true Q-values to train the DQN network,
                        # we do this as pre-training for the DQN network to get better estimates
                        batch_len = len(transitions)
                        # current decoder state (use_state_prime=False), mirroring run_training
                        b = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                                      use_state_prime=False, max_art_oovs=batch.max_art_oovs)
                        # next decoder state (use_state_prime=True) to correct the estimates
                        b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions, len(transitions),
                                                            use_state_prime=True, max_art_oovs=batch.max_art_oovs)
                        dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x=b._x, return_best_action=True)
                        q_estimates = dqn_results['estimates']  # shape (len(transitions), vocab_size)
                        dqn_best_action = dqn_results['best_action']

                        tf.logging.info('running test step on dqn_target')
                        dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x=b_prime._x)
                        q_vals_new_t = dqn_target_results['estimates']  # shape (len(transitions), vocab_size)

                        # we need to expand the q_estimates to match the input batch max_art_oov
                        q_estimates = np.concatenate(
                            [q_estimates, np.zeros((len(transitions), batch.max_art_oovs))], axis=-1)

                        tf.logging.info('fixing the action q-estimates')
                        for i, tr in enumerate(transitions):
                            if tr.done:
                                q_estimates[i][tr.action] = tr.reward
                            else:
                                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]

                        if FLAGS.dqn_scheduled_sampling:
                            tf.logging.info('scheduled sampling on q-estimates')
                            q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability,
                                                                  b._y_extended, q_estimates)
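                        # At this point each row of q_estimates is either the true target
                        # row (b._y_extended) or the DQN estimate, chosen per-row by the
                        # Bernoulli draw inside scheduled_sampling (defined below).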
                        if not FLAGS.calculate_true_q:
                            # when we are not training the DQN on true Q-values,
                            # we need to update the Q-values in our transitions based on the
                            # q_estimates we collected from the current DQN network.
                            for trans, q_val in zip(transitions, q_estimates):
                                trans.q_values = q_val  # each has size vocab_extended
                        q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1])  # shape (batch_size, k, max_dec_steps, vocab_size_extended)
                    tf.logging.info('run eval step on seq2seq model.')
                    t0 = time.time()
                    results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
                    t1 = time.time()
                else:
                    tf.logging.info('run eval step on seq2seq model.')
                    t0 = time.time()
                    results = self.model.run_eval_step(sess, batch, train_step)
                    t1 = time.time()

                tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
                tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1 - t0))

                printer_helper = {}
                loss = printer_helper['pgen_loss'] = results['pgen_loss']
                if FLAGS.coverage:
                    printer_helper['coverage_loss'] = results['coverage_loss']
                    if FLAGS.rl_training or FLAGS.ac_training:
                        printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
                    loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
                if FLAGS.rl_training or FLAGS.ac_training:
                    printer_helper['shared_loss'] = results['shared_loss']
                    printer_helper['rl_loss'] = results['rl_loss']
                    printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
                if FLAGS.rl_training:
                    printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
                    printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
                    printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
                if FLAGS.ac_training:
                    printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

                for (k, v) in printer_helper.items():
                    if not np.isfinite(v):
                        raise Exception("{} is not finite. Stopping.".format(k))
                    tf.logging.info('{}: {}\t'.format(k, v))

                # add summaries
                summaries = results['summaries']
                train_step = results['global_step']
                self.summary_writer.add_summary(summaries, train_step)

                # calculate running avg loss
                avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, train_step))
                tf.logging.info('-------------------------------------------')

            running_avg_loss = np.mean(avg_losses)
            tf.logging.info('==========================================')
            tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
            tf.logging.info('==========================================')

            # If running_avg_loss is the best so far, save this checkpoint (early stopping).
            # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir.
            if best_loss is None or running_avg_loss < best_loss:
                tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s',
                                running_avg_loss, bestmodel_save_path)
                saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
                best_loss = running_avg_loss

            # flush the summary writer every so often
            if train_step % 100 == 0:
                self.summary_writer.flush()
            #time.sleep(600) # run eval every 10 minutes

    def main(self, unused_argv):
        if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
            raise Exception("Problem with flags: %s" % unused_argv)

        # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
        FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
        tf.logging.set_verbosity(tf.logging.INFO)  # choose what level of logging you want
        tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))
        flags = getattr(FLAGS, "__flags")

        if not os.path.exists(FLAGS.log_root):
            if FLAGS.mode == "train":
                os.makedirs(FLAGS.log_root)
            else:
                raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

        fw = open('{}/config.txt'.format(FLAGS.log_root), 'w')
        for k, v in flags.items():
            fw.write('{}\t{}\n'.format(k, v))
        fw.close()

        self.vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

        # If in decode mode, set batch_size = beam_size.
        # Reason: in decode mode, we decode one example at a time.
        # On each step, we have beam_size-many hypotheses in the beam,
        # so we need to make a batch of these hypotheses.
        if FLAGS.mode == 'decode':
            FLAGS.batch_size = FLAGS.beam_size

        # If single_pass=True, check we're in decode mode
        if FLAGS.single_pass and FLAGS.mode != 'decode':
            raise Exception("The single_pass flag should only be True in decode mode")

        # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
        hparam_list = ['mode', 'lr', 'gpu_num',
                       #'sampled_greedy_flag',
                       'gamma', 'eta', 'fixed_eta', 'reward_function', 'intradecoder',
                       'use_temporal_attention', 'ac_training', 'rl_training', 'matrix_attention',
                       'calculate_true_q', 'enc_hidden_dim', 'dec_hidden_dim', 'k',
                       'scheduled_sampling', 'sampling_probability', 'fixed_sampling_probability',
                       'alpha', 'hard_argmax', 'greedy_scheduled_sampling', 'adagrad_init_acc',
                       'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
                       'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps',
                       'dqn_scheduled_sampling', 'dqn_sleep_time', 'E2EBackProp',
                       'coverage', 'cov_loss_wt', 'pointer_gen']
        hps_dict = {}
        for key, val in flags.items():  # for each flag
            if key in hparam_list:  # if it's in the list
                hps_dict[key] = val.value  # add it to the dict
        if FLAGS.ac_training:
            hps_dict.update({'dqn_input_feature_len': (FLAGS.dec_hidden_dim)})
        self.hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

        # creating all the required parameters for the DDQN model
        if FLAGS.ac_training:
            hparam_list = ['lr', 'dqn_gpu_num', 'dqn_layers', 'dqn_replay_buffer_size',
                           'dqn_batch_size', 'dqn_target_update', 'dueling_net',
                           'dqn_polyak_averaging', 'dqn_sleep_time', 'dqn_scheduled_sampling',
                           'max_grad_norm']
            hps_dict = {}
            for key, val in flags.items():  # for each flag
                if key in hparam_list:  # if it's in the list
                    hps_dict[key] = val.value  # add it to the dict
            hps_dict.update({'dqn_input_feature_len': (FLAGS.dec_hidden_dim)})
            hps_dict.update({'vocab_size': self.vocab.size()})
            self.dqn_hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

        # Create a batcher object that will create minibatches of data
        self.batcher = Batcher(FLAGS.data_path, self.vocab, self.hps,
                               single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after)

        tf.set_random_seed(111)  # a seed value for randomness

        if self.hps.mode == 'train':
            print("creating model...")
            self.model = SummarizationModel(self.hps, self.vocab)
            if FLAGS.ac_training:
                # current DQN with parameters \Psi
                self.dqn = DQN(self.dqn_hps, 'current')
                # target DQN with parameters \Psi^{\prime}
                self.dqn_target = DQN(self.dqn_hps, 'target')
            self.setup_training()
        elif self.hps.mode == 'eval':
            self.model = SummarizationModel(self.hps, self.vocab)
            if FLAGS.ac_training:
                self.dqn = DQN(self.dqn_hps, 'current')
                self.dqn_target = DQN(self.dqn_hps, 'target')
            self.run_eval()
        elif self.hps.mode == 'decode':
            decode_model_hps = self.hps  # This will be the hyperparameters for the decoder model
            # The model is configured with max_dec_steps=1 because we only ever run one step
            # of the decoder at a time (to do beam search). Note that the batcher is initialized
            # with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
            decode_model_hps = self.hps._replace(max_dec_steps=1)
            model = SummarizationModel(decode_model_hps, self.vocab)
            if FLAGS.ac_training:
                # We need our target DDQN network for collecting the Q-estimates at each decoder step.
                dqn_target = DQN(self.dqn_hps, 'target')
            else:
                dqn_target = None
            decoder = BeamSearchDecoder(model, self.batcher, self.vocab, dqn=dqn_target)
            decoder.decode()  # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
        else:
            raise ValueError("The 'mode' flag must be one of train/eval/decode")

    # Scheduled sampling is used to select either the true Q-estimates or the DDQN estimate,
    # based on https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper
    def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
        with variable_scope.variable_scope("ScheduledEmbedding"):
            # Return -1s where we do not sample, and sample_ids elsewhere
            select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
            select_sample = select_sampler.sample(sample_shape=batch_size)
            sample_ids = array_ops.where(select_sample,
                                         tf.range(batch_size),
                                         gen_array_ops.fill([batch_size], -1))
            where_sampling = math_ops.cast(array_ops.where(sample_ids > -1), tf.int32)
            where_not_sampling = math_ops.cast(array_ops.where(sample_ids <= -1), tf.int32)
            _estimate = array_ops.gather_nd(estimate, where_sampling)
            _true = array_ops.gather_nd(true, where_not_sampling)
            base_shape = array_ops.shape(true)
            result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
            result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
            return result1 + result2
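

# A numpy analogue of scheduled_sampling above (a minimal sketch, not the TF
# graph code): each row independently takes the estimated row with probability
# sampling_probability and otherwise keeps the true row, which is exactly the
# mixing the gather_nd/scatter_nd construction performs.
import numpy as np

def scheduled_sampling_np(sampling_probability, true, estimate):
    batch_size = true.shape[0]
    use_estimate = np.random.random(batch_size) < sampling_probability  # per-row Bernoulli draw
    return np.where(use_estimate[:, None], estimate, true)

true_q = np.zeros((4, 3))
estimated_q = np.ones((4, 3))
mixed = scheduled_sampling_np(0.5, true_q, estimated_q)
assert mixed.shape == (4, 3)  # each row is all-zeros (true) or all-ones (estimate)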
class Seq2Seq(object): def calc_running_avg_loss(self, loss, running_avg_loss, step, decay=0.99): """Calculate the running average loss via exponential decay. This is used to implement early stopping w.r.t. a more smooth loss curve than the raw loss curve. Args: loss: loss on the most recent eval step running_avg_loss: running_avg_loss so far summary_writer: FileWriter object to write for tensorboard step: training iteration step decay: rate of exponential decay, a float between 0 and 1. Larger is smoother. Returns: running_avg_loss: new running average loss """ if running_avg_loss == 0: # on the first iteration just take the loss running_avg_loss = loss else: running_avg_loss = running_avg_loss * decay + (1 - decay) * loss running_avg_loss = min(running_avg_loss, 12) # clip loss_sum = tf.Summary() tag_name = 'running_avg_loss/decay=%f' % (decay) loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss) self.summary_writer.add_summary(loss_sum, step) tf.logging.info('running_avg_loss: %f', running_avg_loss) return running_avg_loss def restore_best_model(self): """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory""" tf.logging.info("Restoring bestmodel for training...") # Initialize all vars in the model sess = tf.Session(config=util.get_config()) print "Initializing all variables..." sess.run(tf.initialize_all_variables()) # Restore the best model from eval dir saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name]) print "Restoring all non-adagrad variables from best model in eval dir..." curr_ckpt = util.load_ckpt(saver, sess, "eval") print "Restored %s." % curr_ckpt # Save this model to train dir and quit new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model") new_fname = os.path.join(FLAGS.log_root, "train", new_model_name) print "Saving model to %s..." % (new_fname) new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables new_saver.save(sess, new_fname) print "Saved." exit() def restore_best_eval_model(self): # load best evaluation loss so far best_loss = None best_step = None # goes through all event files and select the best loss achieved and return it event_files = sorted(glob('{}/eval/events*'.format(FLAGS.log_root))) for ef in event_files: try: for e in tf.train.summary_iterator(ef): for v in e.summary.value: step = e.step if 'running_avg_loss/decay' in v.tag: running_avg_loss = v.simple_value if best_loss is None or running_avg_loss < best_loss: best_loss = running_avg_loss best_step = step except: continue tf.logging.info('resotring best loss from the current logs: {}\tstep: {}'.format(best_loss, best_step)) return best_loss def convert_to_coverage_model(self): """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint""" tf.logging.info("converting non-coverage model to coverage model..") # initialize an entire coverage model from scratch sess = tf.Session(config=util.get_config()) print "initializing everything..." sess.run(tf.global_variables_initializer()) # load all non-coverage weights from checkpoint saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name]) print "restoring non-coverage variables..." curr_ckpt = util.load_ckpt(saver, sess) print "restored." # save this model and quit new_fname = curr_ckpt + '_cov_init' print "saving model to %s..." 
% (new_fname) new_saver = tf.train.Saver() # this one will save all variables that now exist new_saver.save(sess, new_fname) print "saved." exit() def convert_to_reinforce_model(self): """Load non-reinforce checkpoint, add initialized extra variables for reinforce, and save as new checkpoint""" tf.logging.info("converting non-reinforce model to reinforce model..") # initialize an entire reinforce model from scratch sess = tf.Session(config=util.get_config()) print "initializing everything..." sess.run(tf.global_variables_initializer()) # load all non-reinforce weights from checkpoint saver = tf.train.Saver([v for v in tf.global_variables() if "reinforce" not in v.name and "Adagrad" not in v.name]) print "restoring non-reinforce variables..." curr_ckpt = util.load_ckpt(saver, sess) print "restored." # save this model and quit new_fname = curr_ckpt + '_rl_init' print "saving model to %s..." % (new_fname) new_saver = tf.train.Saver() # this one will save all variables that now exist new_saver.save(sess, new_fname) print "saved." exit() def setup_training(self): """Does setup before starting training (run_training)""" train_dir = os.path.join(FLAGS.log_root, "train") if not os.path.exists(train_dir): os.makedirs(train_dir) if FLAGS.ac_training: dqn_train_dir = os.path.join(FLAGS.log_root, "dqn", "train") if not os.path.exists(dqn_train_dir): os.makedirs(dqn_train_dir) #replaybuffer_pcl_path = os.path.join(FLAGS.log_root, "replaybuffer.pcl") #if not os.path.exists(dqn_target_train_dir): os.makedirs(dqn_target_train_dir) self.model.build_graph() # build the graph if FLAGS.convert_to_reinforce_model: assert (FLAGS.rl_training or FLAGS.ac_training), "To convert your pointer model to a reinforce model, run with convert_to_reinforce_model=True and either rl_training=True or ac_training=True" self.convert_to_reinforce_model() if FLAGS.convert_to_coverage_model: assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True" self.convert_to_coverage_model() if FLAGS.restore_best_model: self.restore_best_model() saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time # Loads pre-trained word-embedding. By default the model learns the embedding. 
if FLAGS.embedding: self.vocab.LoadWordEmbedding(FLAGS.embedding, FLAGS.emb_dim) word_vector = self.vocab.getWordEmbedding() self.sv = tf.train.Supervisor(logdir=train_dir, is_chief=True, saver=saver, summary_op=None, save_summaries_secs=60, # save summaries for tensorboard every 60 secs save_model_secs=60, # checkpoint every 60 secs global_step=self.model.global_step, init_feed_dict= {self.model.embedding_place:word_vector} if FLAGS.embedding else None ) self.summary_writer = self.sv.summary_writer self.sess = self.sv.prepare_or_wait_for_session(config=util.get_config()) if FLAGS.ac_training: tf.logging.info('DDQN building graph') t1 = time.time() # We create a separate graph for DDQN self.dqn_graph = tf.Graph() with self.dqn_graph.as_default(): self.dqn.build_graph() # build dqn graph tf.logging.info('building current network took {} seconds'.format(time.time()-t1)) self.dqn_target.build_graph() # build dqn target graph tf.logging.info('building target network took {} seconds'.format(time.time()-t1)) dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time self.dqn_sv = tf.train.Supervisor(logdir=dqn_train_dir, is_chief=True, saver=dqn_saver, summary_op=None, save_summaries_secs=60, # save summaries for tensorboard every 60 secs save_model_secs=60, # checkpoint every 60 secs global_step=self.dqn.global_step, ) self.dqn_summary_writer = self.dqn_sv.summary_writer self.dqn_sess = self.dqn_sv.prepare_or_wait_for_session(config=util.get_config()) ''' #### TODO: try loading a previously saved replay buffer # right now this doesn't work due to running DQN on a thread if os.path.exists(replaybuffer_pcl_path): tf.logging.info('Loading Replay Buffer...') try: self.replay_buffer = pickle.load(open(replaybuffer_pcl_path, "rb")) tf.logging.info('Replay Buffer loaded...') except: tf.logging.info('Couldn\'t load Replay Buffer file...') self.replay_buffer = ReplayBuffer(self.dqn_hps) else: self.replay_buffer = ReplayBuffer(self.dqn_hps) tf.logging.info("Building DDQN took {} seconds".format(time.time()-t1)) ''' self.replay_buffer = ReplayBuffer(self.dqn_hps) tf.logging.info("Preparing or waiting for session...") tf.logging.info("Created session.") try: self.run_training() # this is an infinite loop until interrupted except (KeyboardInterrupt, SystemExit): tf.logging.info("Caught keyboard interrupt on worker. 
Stopping supervisor...") self.sv.stop() if FLAGS.ac_training: self.dqn_sv.stop() def run_training(self): """Repeatedly runs training iterations, logging loss to screen and writing summaries""" tf.logging.info("Starting run_training") if FLAGS.debug: # start the tensorflow debugger self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess) self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) self.train_step = 0 if FLAGS.ac_training: # DDQN training is done asynchronously along with model training tf.logging.info('Starting DQN training thread...') self.dqn_train_step = 0 self.thrd_dqn_training = Thread(target=self.dqn_training) self.thrd_dqn_training.daemon = True self.thrd_dqn_training.start() watcher = Thread(target=self.watch_threads) watcher.daemon = True watcher.start() # starting the main thread tf.logging.info('Starting Seq2Seq training...') while True: # repeats until interrupted batch = self.batcher.next_batch() t0=time.time() if FLAGS.ac_training: # For DDQN, we first collect the model output to calculate the reward and Q-estimates # Then we fix the estimation either using our target network or using the true Q-values # This process will usually take time and we are working on improving it. transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps) tf.logging.info('Q-values collection time: {}'.format(time.time()-t0)) # whenever we are working with the DDQN, we switch using DDQN graph rather than default graph with self.dqn_graph.as_default(): batch_len = len(transitions) # we use current decoder state to predict q_estimates, use_state_prime = False b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = False, max_art_oovs = batch.max_art_oovs) # we also get the next decoder state to correct the estimation, use_state_prime = True b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs) # use current DQN to estimate values from current decoder state dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x= b._x, return_best_action=True) q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size) dqn_best_action = dqn_results['best_action'] #dqn_q_estimate_loss = dqn_results['loss'] # use target DQN to estimate values for the next decoder state dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x= b_prime._x) q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size) # we need to expand the q_estimates to match the input batch max_art_oov # we use the q_estimate of UNK token for all the OOV tokens q_estimates = np.concatenate([q_estimates, np.reshape(q_estimates[:,0],[-1,1])*np.ones((len(transitions),batch.max_art_oovs))],axis=-1) # modify Q-estimates using the result collected from current and target DQN. 
# check algorithm 5 in the paper for more info: https://arxiv.org/pdf/1805.09461.pdf for i, tr in enumerate(transitions): if tr.done: q_estimates[i][tr.action] = tr.reward else: q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]] # use scheduled sampling to whether use true Q-values or DDQN estimation if FLAGS.dqn_scheduled_sampling: q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates) if not FLAGS.calculate_true_q: # when we are not training DDQN based on true Q-values, # we need to update Q-values in our transitions based on the q_estimates we collected from DQN current network. for trans, q_val in zip(transitions,q_estimates): trans.q_values = q_val # each have the size vocab_extended q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended) # Once we are done with modifying Q-values, we can use them to train the DDQN model. # In this paper, we use a priority experience buffer which always selects states with higher quality # to train the DDQN. The following line will add batch_size * max_dec_steps experiences to the replay buffer. # As mentioned before, the DDQN training is asynchronous. Therefore, once the related queues for DDQN training # are full, the DDQN will start the training. self.replay_buffer.add(transitions) # If dqn_pretrain flag is on, it means that we use a fixed Actor to only collect experiences for # DDQN pre-training if FLAGS.dqn_pretrain: tf.logging.info('RUNNNING DQN PRETRAIN: Adding data to relplay buffer only...') continue # if not, use the q_estimation to update the loss. results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates) else: results = self.model.run_train_steps(self.sess, batch, self.train_step) t1=time.time() # get the summaries and iteration number so we can write summaries to tensorboard summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer self.train_step = results['global_step'] # we need this to update our running average loss tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1-t0)) printer_helper = {} printer_helper['pgen_loss']= results['pgen_loss'] if FLAGS.coverage: printer_helper['coverage_loss'] = results['coverage_loss'] if FLAGS.rl_training or FLAGS.ac_training: printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss'] else: printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss'] if FLAGS.rl_training or FLAGS.ac_training: printer_helper['shared_loss'] = results['shared_loss'] printer_helper['rl_loss'] = results['rl_loss'] printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs'] if FLAGS.rl_training: printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values']) printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values']) printer_helper['r_diff'] = printer_helper['sampled_r'] - printer_helper['greedy_r'] if FLAGS.ac_training: printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss)>0 else 0 for (k,v) in printer_helper.items(): if not np.isfinite(v): raise Exception("{} is not finite. 
Stopping.".format(k)) tf.logging.info('{}: {}\t'.format(k,v)) tf.logging.info('-------------------------------------------') self.summary_writer.add_summary(summaries, self.train_step) # write the summaries if self.train_step % 100 == 0: # flush the summary writer every so often self.summary_writer.flush() if FLAGS.ac_training: self.dqn_summary_writer.flush() if self.train_step > FLAGS.max_iter: break def dqn_training(self): """ training the DDQN network.""" try: while True: if self.dqn_train_step == FLAGS.dqn_pretrain_steps: raise SystemExit() _t = time.time() self.avg_dqn_loss = [] avg_dqn_target_loss = [] # Get a batch of size dqn_batch_size from replay buffer to train the model dqn_batch = self.replay_buffer.next_batch() if dqn_batch is None: tf.logging.info('replay buffer not loaded enough yet...') time.sleep(60) continue # Run train step for Current DQN model and collect the results dqn_results = self.dqn.run_train_steps(self.dqn_sess, dqn_batch) # Run test step for Target DQN model and collect the results and monitor the difference in loss between the two dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x=dqn_batch._x, y=dqn_batch._y, return_loss=True) self.dqn_train_step = dqn_results['global_step'] self.dqn_summary_writer.add_summary(dqn_results['summaries'], self.dqn_train_step) # write the summaries self.avg_dqn_loss.append(dqn_results['loss']) avg_dqn_target_loss.append(dqn_target_results['loss']) self.dqn_train_step = self.dqn_train_step + 1 tf.logging.info('seconds for training dqn model: {}'.format(time.time()-_t)) # UPDATING TARGET DDQN NETWORK WITH CURRENT MODEL with self.dqn_graph.as_default(): current_model_weights = self.dqn_sess.run([self.dqn.model_trainables])[0] # get weights of current model self.dqn_target.run_update_weights(self.dqn_sess, self.dqn_train_step, current_model_weights) # update target model weights with current model weights tf.logging.info('DQN loss at step {}: {}'.format(self.dqn_train_step, np.mean(self.avg_dqn_loss))) tf.logging.info('DQN Target loss at step {}: {}'.format(self.dqn_train_step, np.mean(avg_dqn_target_loss))) # sleeping is required if you want the keyboard interuption to work time.sleep(FLAGS.dqn_sleep_time) except (KeyboardInterrupt, SystemExit): tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...") self.sv.stop() self.dqn_sv.stop() def watch_threads(self): """Watch example queue and batch queue threads and restart if dead.""" while True: time.sleep(60) if not self.thrd_dqn_training.is_alive(): # if the thread is dead tf.logging.error('Found DQN Learning thread dead. Restarting.') self.thrd_dqn_training = Thread(target=self.dqn_training) self.thrd_dqn_training.daemon = True self.thrd_dqn_training.start() def run_eval(self): """Repeatedly runs eval iterations, logging to screen and writing summaries. 
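dqn_training above syncs the target network with the current network through run_update_weights. With the dqn_polyak_averaging flag (see the DDQN hyperparameter list below) the update is presumably a soft blend rather than a hard copy; a sketch of both variants on plain NumPy weight lists (the helper name and tau value are illustrative, not the repo's API):

import numpy as np

def update_target_weights(current_weights, target_weights, tau=0.999, polyak=True):
    """Hard-copy or Polyak-average current-network weights into the target network."""
    if not polyak:
        return [w.copy() for w in current_weights]
    # soft update: target <- tau * target + (1 - tau) * current
    return [tau * t + (1.0 - tau) * c
            for c, t in zip(current_weights, target_weights)]

A slowly moving target keeps the Bellman targets stable while the current network is trained on every replay batch.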
def run_eval(self): """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" self.model.build_graph() # build the graph saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time sess = tf.Session(config=util.get_config()) if FLAGS.embedding: sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector}) eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved summary_writer = tf.summary.FileWriter(eval_dir) if FLAGS.ac_training: tf.logging.info('DDQN building graph') t1 = time.time() dqn_graph = tf.Graph() with dqn_graph.as_default(): self.dqn.build_graph() # build dqn graph tf.logging.info('building current network took {} seconds'.format(time.time()-t1)) t1 = time.time() self.dqn_target.build_graph() # build dqn target graph tf.logging.info('building target network took {} seconds'.format(time.time()-t1)) dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time dqn_sess = tf.Session(config=util.get_config()) dqn_train_step = 0 replay_buffer = ReplayBuffer(self.dqn_hps) running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping best_loss = self.restore_best_eval_model() # will hold the best loss achieved so far train_step = 0 while True: _ = util.load_ckpt(saver, sess) # load a new checkpoint if FLAGS.ac_training: _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint processed_batch = 0 avg_losses = [] # evaluate on 100 * batch_size examples before comparing the loss # we do this due to memory constraints; it is best to run eval on a different machine with a large batch size while processed_batch < 100*FLAGS.batch_size: processed_batch += FLAGS.batch_size batch = self.batcher.next_batch() # get the next batch if FLAGS.ac_training: t0 = time.time() transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps) tf.logging.info('Q values collection time: {}'.format(time.time()-t0)) with dqn_graph.as_default(): # if using true Q-values to train the DQN network, # we do this as pre-training for the DQN network to get better estimates batch_len = len(transitions) # current decoder state for the estimates, next state for the correction (mirrors run_training) b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = False, max_art_oovs = batch.max_art_oovs) b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs) dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True) q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size) dqn_best_action = dqn_results['best_action'] tf.logging.info('running test step on dqn_target') dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x) q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size) # we need to expand the q_estimates to match the input batch max_art_oov q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1) tf.logging.info('fixing the action q-estimates') for i, tr in enumerate(transitions): if tr.done: q_estimates[i][tr.action] = tr.reward else: q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]] if FLAGS.dqn_scheduled_sampling: tf.logging.info('scheduled sampling on q-estimates') q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
if not FLAGS.calculate_true_q: # when we are not training the DQN on true Q-values, # we need to update the Q-values in our transitions based on the q_estimates we collected from the current DQN network. for trans, q_val in zip(transitions,q_estimates): trans.q_values = q_val # each has size vocab_extended q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended) tf.logging.info('run eval step on seq2seq model.') t0=time.time() results = self.model.run_eval_step(sess, batch, train_step, q_estimates) t1=time.time() else: tf.logging.info('run eval step on seq2seq model.') t0=time.time() results = self.model.run_eval_step(sess, batch, train_step) t1=time.time() tf.logging.info('experiment: {}'.format(FLAGS.exp_name)) tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0)) printer_helper = {} loss = printer_helper['pgen_loss']= results['pgen_loss'] if FLAGS.coverage: printer_helper['coverage_loss'] = results['coverage_loss'] if FLAGS.rl_training or FLAGS.ac_training: loss = printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss'] else: loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss'] if FLAGS.rl_training or FLAGS.ac_training: printer_helper['shared_loss'] = results['shared_loss'] printer_helper['rl_loss'] = results['rl_loss'] printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs'] for (k,v) in printer_helper.items(): if not np.isfinite(v): raise Exception("{} is not finite. Stopping.".format(k)) tf.logging.info('{}: {}\t'.format(k,v)) # add summaries summaries = results['summaries'] train_step = results['global_step'] summary_writer.add_summary(summaries, train_step) # calculate running avg loss avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step)) tf.logging.info('-------------------------------------------') running_avg_loss = np.mean(avg_losses) tf.logging.info('==========================================') tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss)) tf.logging.info('==========================================')
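calc_running_avg_loss above smooths the raw eval losses so that early stopping does not react to single-batch noise. A sketch of the usual exponential moving average (the decay of 0.99 and the clip value are assumptions; the repo's helper also writes the smoothed value as a TensorBoard summary):

def calc_running_avg_loss(loss, running_avg_loss, decay=0.99):
    """Exponential moving average of the eval loss; summary-writing omitted."""
    if running_avg_loss == 0:  # on the first iteration just take the loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    return min(running_avg_loss, 12)  # clip so early spikes don't dominate the plot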
# If running_avg_loss is the best so far, save this checkpoint (early stopping). # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir if best_loss is None or running_avg_loss < best_loss: tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path) saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best') best_loss = running_avg_loss # flush the summary writer every so often if train_step % 100 == 0: summary_writer.flush() #time.sleep(600) # run eval every 10 minutes def main(self, unused_argv): if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly raise Exception("Problem with flags: %s" % unused_argv) # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode)) flags = getattr(FLAGS,"__flags") if not os.path.exists(FLAGS.log_root): if FLAGS.mode=="train": os.makedirs(FLAGS.log_root) fw = open('{}/config.txt'.format(FLAGS.log_root),'w') for k,v in flags.iteritems(): fw.write('{}\t{}\n'.format(k,v)) fw.close() else: raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) self.vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary # If in decode mode, set batch_size = beam_size # Reason: in decode mode, we decode one example at a time. # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses. if FLAGS.mode == 'decode': FLAGS.batch_size = FLAGS.beam_size # If single_pass=True, check we're in decode mode if FLAGS.single_pass and FLAGS.mode!='decode': raise Exception("The single_pass flag should only be True in decode mode") # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = ['mode', 'lr', 'gpu_num', #'sampled_greedy_flag', 'gamma', 'eta', 'fixed_eta', 'reward_function', 'intradecoder', 'use_temporal_attention', 'ac_training','rl_training', 'matrix_attention', 'calculate_true_q', 'enc_hidden_dim', 'dec_hidden_dim', 'k', 'scheduled_sampling', 'sampling_probability','fixed_sampling_probability', 'alpha', 'hard_argmax', 'greedy_scheduled_sampling', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'dqn_scheduled_sampling', 'dqn_sleep_time', 'E2EBackProp', 'coverage', 'cov_loss_wt', 'pointer_gen'] hps_dict = {} for key,val in flags.iteritems(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict if FLAGS.ac_training: hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)}) self.hps = namedtuple("HParams", hps_dict.keys())(**hps_dict) # creating all the required parameters for the DDQN model.
if FLAGS.ac_training: hparam_list = ['lr', 'dqn_gpu_num', 'dqn_layers', 'dqn_replay_buffer_size', 'dqn_batch_size', 'dqn_target_update', 'dueling_net', 'dqn_polyak_averaging', 'dqn_sleep_time', 'dqn_scheduled_sampling', 'max_grad_norm'] hps_dict = {} for key,val in flags.iteritems(): # for each flag if key in hparam_list: # if it's in the list hps_dict[key] = val # add it to the dict hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)}) hps_dict.update({'vocab_size':self.vocab.size()}) self.dqn_hps = namedtuple("HParams", hps_dict.keys())(**hps_dict) # Create a batcher object that will create minibatches of data self.batcher = Batcher(FLAGS.data_path, self.vocab, self.hps, single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after) tf.set_random_seed(111) # a seed value for randomness if self.hps.mode == 'train': print "creating model..." self.model = SummarizationModel(self.hps, self.vocab) if FLAGS.ac_training: # current DQN with parameters \Psi self.dqn = DQN(self.dqn_hps,'current') # target DQN with parameters \Psi^{\prime} self.dqn_target = DQN(self.dqn_hps,'target') self.setup_training() elif self.hps.mode == 'eval': self.model = SummarizationModel(self.hps, self.vocab) if FLAGS.ac_training: self.dqn = DQN(self.dqn_hps,'current') self.dqn_target = DQN(self.dqn_hps,'target') self.run_eval() elif self.hps.mode == 'decode': # Hyperparameters for the decoder model decode_model_hps = self.hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries model = SummarizationModel(decode_model_hps, self.vocab) if FLAGS.ac_training: # We need our target DDQN network for collecting Q-estimates at each decoder step. dqn_target = DQN(self.dqn_hps,'target') else: dqn_target = None decoder = BeamSearchDecoder(model, self.batcher, self.vocab, dqn = dqn_target) decoder.decode() # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once) else: raise ValueError("The 'mode' flag must be one of train/eval/decode") # Scheduled sampling used to select either the true Q-values or the DDQN estimates # based on https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper def scheduled_sampling(self, batch_size, sampling_probability, true, estimate): with variable_scope.variable_scope("ScheduledEmbedding"): # Return -1s where we do not sample, and sample_ids elsewhere select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool) select_sample = select_sampler.sample(sample_shape=batch_size) sample_ids = array_ops.where( select_sample, tf.range(batch_size), gen_array_ops.fill([batch_size], -1)) where_sampling = math_ops.cast( array_ops.where(sample_ids > -1), tf.int32) where_not_sampling = math_ops.cast( array_ops.where(sample_ids <= -1), tf.int32) _estimate = array_ops.gather_nd(estimate, where_sampling) _true = array_ops.gather_nd(true, where_not_sampling) base_shape = array_ops.shape(true) result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape) result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape) return result1 + result2
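The TF ops in scheduled_sampling can obscure what is a simple row-wise mixture: each row keeps its true Q-values with probability 1 - sampling_probability and takes the DDQN estimates otherwise. A NumPy equivalent for intuition (not used by the repo):

import numpy as np

def scheduled_sampling_np(batch_size, sampling_probability, true_vals, estimates, seed=None):
    """Per row: pick the estimate row with probability sampling_probability, else the true row."""
    rng = np.random.default_rng(seed)
    take_estimate = rng.random(batch_size) < sampling_probability  # (batch,) bool mask
    return np.where(take_estimate[:, None], estimates, true_vals)  # broadcast over vocab dim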
class BeamSearch(object): def __init__(self, model, config, step): self.config = config self.model = model.to(device) self._decode_dir = os.path.join(config.log_root, 'decode_S%s' % str(step)) self._rouge_ref = os.path.join(self._decode_dir, 'rouge_ref') self._rouge_dec = os.path.join(self._decode_dir, 'rouge_dec') if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir) self.vocab = Vocab(config.vocab_file, config.vocab_size) self.test_data = CNNDMDataset('test', config.data_path, config, self.vocab) def sort_beams(self, beams): return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True) @staticmethod def report_rouge(ref_path, dec_path): print("Now starting ROUGE eval...") files_rouge = FilesRouge(dec_path, ref_path) scores = files_rouge.get_scores(avg=True) logging(str(scores)) #@staticmethod def get_summary(self, best_summary, batch): # Extract the output ids from the hypothesis and convert back to words output_ids = [int(t) for t in best_summary.tokens[1:]] decoded_words = output2words( output_ids, self.vocab, (batch.art_oovs[0] if self.config.pointer_gen else None)) # Remove the [STOP] token from decoded_words, if necessary try: fst_stop_idx = decoded_words.index('<end>') decoded_words = decoded_words[:fst_stop_idx] except ValueError: decoded_words = decoded_words decoded_abstract = ' '.join(decoded_words) return decoded_abstract def decode(self): config = self.config start = time.time() counter = 0 test_loader = DataLoader( self.test_data, batch_size=1, shuffle=False, collate_fn=Collate(beam_size=config.beam_size)) ref = open(self._rouge_ref, 'w') dec = open(self._rouge_dec, 'w') for batch in test_loader: # Run beam search to get best Hypothesis best_summary = self.beam_search(batch) original_abstract = batch.original_abstract[0] decoded_abstract = self.get_summary(best_summary, batch) ref.write(original_abstract + '\n') dec.write(decoded_abstract + '\n') counter += 1 if counter % 1000 == 0: print('%d example in %d sec' % (counter, time.time() - start)) start = time.time() print("Decoder has finished reading dataset for single_pass.") ref.close() dec.close() self.report_rouge(self._rouge_ref, self._rouge_dec) def beam_search(self, batch): config = self.config #batch should have only one example enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \ get_input_from_batch(batch, config, device) encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder( enc_batch, enc_lens) s_t_0 = self.model.reduce_state(encoder_hidden) dec_h, dec_c = s_t_0 # 1 x 2*hidden_size dec_h = dec_h.squeeze() dec_c = dec_c.squeeze() #decoder batch preparation, it has beam_size example initially everything is repeated beams = [ Beam(tokens=[self.vocab.word2id('<start>')], log_probs=[0.0], state=(dec_h[0], dec_c[0]), context=c_t_0[0], coverage=(coverage_t_0[0] if config.is_coverage else None)) for _ in range(config.beam_size) ] results = [] steps = 0 while steps < config.max_dec_steps and len(results) < config.beam_size: latest_tokens = [h.latest_token for h in beams] latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id('<unk>') \ for t in latest_tokens] y_t_1 = Variable(torch.tensor(latest_tokens)).to(device) all_state_h = [] all_state_c = [] all_context = [] for h in beams: state_h, state_c = h.state all_state_h.append(state_h) all_state_c.append(state_c) all_context.append(h.context) s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0)) c_t_1 = torch.stack(all_context, 0) coverage_t_1 
= None if config.is_coverage: all_coverage = [] for h in beams: all_coverage.append(h.coverage) coverage_t_1 = torch.stack(all_coverage, 0) final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder( y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps) log_probs = torch.log(final_dist) topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2) dec_h, dec_c = s_t dec_h = dec_h.squeeze() dec_c = dec_c.squeeze() all_beams = [] num_orig_beams = 1 if steps == 0 else len(beams) for i in range(num_orig_beams): h = beams[i] state_i = (dec_h[i], dec_c[i]) context_i = c_t[i] coverage_i = (coverage_t[i] if config.is_coverage else None) for j in range(config.beam_size * 2): # for each of the top 2*beam_size hyps: new_beam = h.extend(token=topk_ids[i, j].item(), log_prob=topk_log_probs[i, j].item(), state=state_i, context=context_i, coverage=coverage_i) all_beams.append(new_beam) beams = [] for h in self.sort_beams(all_beams): if h.latest_token == self.vocab.word2id('<end>'): if steps >= config.min_dec_steps: results.append(h) else: beams.append(h) if len(beams) == config.beam_size or len( results) == config.beam_size: break steps += 1 if len(results) == 0: results = beams beams_sorted = self.sort_beams(results) return beams_sorted[0]
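The BeamSearch classes here construct and extend a Beam hypothesis object that is never defined in this file. A sketch of the interface they assume (field comments are inferred from the call sites above; the actual class may differ in details):

class Beam(object):
    """Partial hypothesis tracked during beam search."""
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # token ids decoded so far, starting with <start>
        self.log_probs = log_probs  # per-token log-probabilities
        self.state = state          # (h, c) decoder LSTM state
        self.context = context      # last attention context vector
        self.coverage = coverage    # running coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        """Return a new hypothesis with one more decoded token."""
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state, context=context, coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # length-normalized score, so longer hypotheses are not unfairly penalized
        return sum(self.log_probs) / len(self.tokens)

sort_beams ranks hypotheses by avg_log_prob, which is why the normalization by length matters: raw summed log-probabilities would always prefer the shortest candidate.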
class BeamSearch(object): def __init__(self, model_file_path, data_path, data_class='val'): self.data_class = data_class if self.data_class not in ['val', 'test']: print("data_class must be 'val' or 'test'.") raise ValueError # model_file_path e.g. --> ../log/{MODE NAME}/best_model/model_best_XXXXX model_name = os.path.basename(model_file_path) # log_root e.g. --> ../log/{MODE NAME}/ log_root = os.path.dirname(os.path.dirname(model_file_path)) # _decode_dir e.g. --> ../log/{MODE NAME}/decode_model_best_XXXXX/ self._decode_dir = os.path.join(log_root, 'decode_%s' % (model_name)) self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref') self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir') self._result_path = os.path.join(self._decode_dir, 'result_%s_%s.txt' \ % (model_name, self.data_class)) # remove result file if exist if os.path.isfile(self._result_path): os.remove(self._result_path) for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]: if not os.path.exists(p): os.mkdir(p) self.vocab = Vocab(config.vocab_path, config.vocab_size) self.batcher = Batcher(data_path, self.vocab, mode='decode', batch_size=config.beam_size, single_pass=True) time.sleep(5) self.model = Model(model_file_path, is_eval=True) def sort_beams(self, beams): return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True) def beam_search(self, batch): # batch should have only one example enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \ get_input_from_batch(batch, use_cuda) encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens) s_t_0 = self.model.reduce_state(encoder_hidden) dec_h, dec_c = s_t_0 # 1 x 2H dec_h = dec_h.squeeze() dec_c = dec_c.squeeze() # decoder batch preparation, it has beam_size example initially everything is repeated beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)], log_probs=[0.0], state=(dec_h[0], dec_c[0]), context = c_t_0[0], coverage=(coverage_t_0[0] if config.is_coverage else None)) for _ in range(config.beam_size)] results = [] steps = 0 while steps < config.max_dec_steps and len(results) < config.beam_size: latest_tokens = [h.latest_token for h in beams] latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \ for t in latest_tokens] y_t_1 = Variable(torch.LongTensor(latest_tokens)) if use_cuda: y_t_1 = y_t_1.cuda() all_state_h =[] all_state_c = [] all_context = [] for h in beams: state_h, state_c = h.state all_state_h.append(state_h) all_state_c.append(state_c) all_context.append(h.context) s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0)) c_t_1 = torch.stack(all_context, 0) coverage_t_1 = None if config.is_coverage: all_coverage = [] for h in beams: all_coverage.append(h.coverage) coverage_t_1 = torch.stack(all_coverage, 0) final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps) log_probs = torch.log(final_dist) topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2) dec_h, dec_c = s_t dec_h = dec_h.squeeze() dec_c = dec_c.squeeze() all_beams = [] num_orig_beams = 1 if steps == 0 else len(beams) for i in range(num_orig_beams): h = beams[i] state_i = (dec_h[i], dec_c[i]) context_i = c_t[i] coverage_i = (coverage_t[i] if config.is_coverage else None) for j in range(config.beam_size * 2): # for each of the top 
2*beam_size hyps: new_beam = h.extend(token=topk_ids[i, j].item(), log_prob=topk_log_probs[i, j].item(), state=state_i, context=context_i, coverage=coverage_i) all_beams.append(new_beam) beams = [] for h in self.sort_beams(all_beams): if h.latest_token == self.vocab.word2id(data.STOP_DECODING): if steps >= config.min_dec_steps: results.append(h) else: beams.append(h) if len(beams) == config.beam_size or len(results) == config.beam_size: break steps += 1 if len(results) == 0: results = beams beams_sorted = self.sort_beams(results) return beams_sorted[0] def decode(self): start = time.time() counter = 0 bleu_scores = [] batch = self.batcher.next_batch() while batch is not None: # Run beam search to get best Hypothesis best_summary = self.beam_search(batch) # Extract the output ids from the hypothesis and convert back to words output_ids = [int(t) for t in best_summary.tokens[1:]] decoded_words = data.outputids2words(output_ids, self.vocab, (batch.art_oovs[0] if config.pointer_gen else None)) # Remove the [STOP] token from decoded_words, if necessary try: fst_stop_idx = decoded_words.index(data.STOP_DECODING) decoded_words = decoded_words[:fst_stop_idx] except ValueError: decoded_words = decoded_words original_articles = batch.original_articles[0] original_abstracts = batch.original_abstracts_sents[0] reference = original_abstracts[0].strip().split() bleu = nltk.translate.bleu_score.sentence_bleu([reference], decoded_words, weights = (0.5, 0.5)) bleu_scores.append(bleu) # write_for_rouge(original_abstracts, decoded_words, counter, # self._rouge_ref_dir, self._rouge_dec_dir) write_for_result(original_articles, original_abstracts, decoded_words, \ self._result_path, self.data_class) counter += 1 if counter % 1000 == 0: print('%d example in %d sec'%(counter, time.time() - start)) start = time.time() batch = self.batcher.next_batch() ''' # uncomment this if you successfully install `pyrouge` print("Decoder has finished reading dataset for single_pass.") print("Now starting ROUGE eval...") results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir) rouge_log(results_dict, self._decode_dir) ''' if self.data_class == 'val': print('Average BLEU score:', np.mean(bleu_scores)) with open(self._result_path, "a") as f: print('Average BLEU score:', np.mean(bleu_scores), file=f) def get_processed_path(self): # ../log/{MODE NAME}/decode_model_best_XXXXX/result_model_best_2800_{data_class}.txt input_path = self._result_path temp = os.path.splitext(input_path) # ../log/{MODE NAME}/decode_model_best_XXXXX/result_model_best_2800_{data_class}_processed.txt output_path = temp[0] + "_processed" + temp[1] return input_path, output_path
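decode() above scores each summary with sentence-level BLEU using weights=(0.5, 0.5), i.e. the geometric mean of unigram and bigram precision only. A quick standalone check of that call:

import nltk

reference = 'the cat sat on the mat'.split()
candidate = 'the cat is on the mat'.split()
# weights=(0.5, 0.5) restricts BLEU to unigram and bigram precision,
# matching the call in decode() above
bleu = nltk.translate.bleu_score.sentence_bleu([reference], candidate,
                                               weights=(0.5, 0.5))
print('BLEU-2: %.3f' % bleu)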
def main(): embedding_dict_file = os.path.join(os.path.dirname(hps.word_count_path), 'emb_dict_50000.pkl') vocab = Vocab(hps.word_count_path, hps.glove_path, hps.embedding_dim, hps.max_vocab_size, embedding_dict_file) train_file = os.path.join(hps.data_path, 'train_raw.json') dev_file = os.path.join(hps.data_path, 'dev_raw.json') if (not os.path.exists(train_file)) \ or (not os.path.exists(dev_file)): raise Exception( 'train and dev data do not exist in data_path, please check') if hps.save and not hps.exp_dir: raise Exception( 'please specify exp_dir when you want to save experiment info') print(vars(hps)) if hps.save: utils.save_hps(hps.exp_dir, hps) net = PointerNet(hps, vocab.emb_mat) net = net.cuda() model_parameters = list(filter(lambda p: p.requires_grad, net.parameters())) print('the number of parameters in model:', sum(p.numel() for p in model_parameters)) optimizer = optim.Adam(model_parameters) train_data_batcher = Batcher(train_file, vocab, hps, hps.single_pass) dev_data_batcher = Batcher(dev_file, vocab, hps, hps.single_pass) if hps.reward_metric == 'bleu': reward = get_batch_bleu global_step = 0 dev_loss_track = [] min_dev_loss = math.inf for i in range(hps.num_epoch): epoch_loss_track = [] train_data_batcher.setup() while True: start = time.time() try: batch = train_data_batcher.next_batch() #print('get next batch time:', time.time()-start) except StopIteration: # do evaluation here, if necessary, to save best model dev_data_batcher.setup() dev_loss = run_eval(dev_data_batcher, net) print( "epoch {}: avg train loss: {:>10.4f}, dev_loss: {:>10.4f}". format(i + 1, sum(epoch_loss_track) / len(epoch_loss_track), dev_loss)) dev_loss_track.append(dev_loss) if i > hps.early_stopping_from: # compare the last five epochs against the five before them last5devloss = sum(dev_loss_track[i - 4:i + 1]) last10devloss = sum(dev_loss_track[i - 9:i - 4]) if hps.early_stopping_from and last5devloss >= last10devloss: print("early stopping by dev_loss!") sys.exit() if dev_loss < min_dev_loss: min_dev_loss = dev_loss if hps.save: utils.save_model(hps.exp_dir, net, min_dev_loss) break paragraph_tensor = torch.tensor(batch.enc_batch, dtype=torch.int64, requires_grad=False).cuda() question_tensor = torch.tensor(batch.dec_batch, dtype=torch.int64, requires_grad=False).cuda() answer_position_tensor = torch.tensor(batch.ans_indices, dtype=torch.int64, requires_grad=False).cuda() target_tensor = torch.tensor(batch.target_batch, dtype=torch.int64, requires_grad=False).cuda() paragraph_batch_extend_vocab = None max_para_oovs = None if hps.pointer_gen: paragraph_batch_extend_vocab = torch.tensor( batch.enc_batch_extend_vocab, dtype=torch.int64, requires_grad=False).cuda() max_para_oovs = batch.max_para_oovs optimizer.zero_grad() net.train() vocab_scores, vocab_dists, attn_dists, final_dists = net( paragraph_tensor, question_tensor, answer_position_tensor, paragraph_batch_extend_vocab, max_para_oovs) dec_padding_mask = torch.ne(target_tensor, 0).float().cuda() # for self-critic if hps.self_critic: greedy_seq = [ torch.argmax(dist, dim=1, keepdim=True) for dist in final_dists ] # each dist = [batch_size, vsize] greedy_seq_tensor = torch.cat(greedy_seq, dim=1) # [batch_size, seq_len] sample_seq = [] for dist in final_dists: m = torch.distributions.categorical.Categorical(probs=dist) sample_seq.append(m.sample()) # each is [batch_size,] sample_seq_tensor = torch.stack(sample_seq, dim=1) if
hps.pointer_gen: loss_per_step = [] for dist, sample_tgt in zip(final_dists, sample_seq): # dist = [batch_size, extended_vsize] probs = torch.gather( dist, 1, sample_tgt.unsqueeze(1)).squeeze() losses = -torch.log(probs) loss_per_step.append(losses) # a list of [batch_size,] rl_loss = mask_and_avg(loss_per_step, dec_padding_mask, batch_average=False, step_average=False) # this rl_loss = [batch_size, ] else: # a list of dec_max_len (vocab_scores) loss_batch_by_step = F.cross_entropy( torch.stack(vocab_scores, dim=1).reshape(-1, vocab.size()), sample_seq_tensor.reshape(-1), size_average=False, reduce=False) # loss [batch_size*dec_max_len,] mask_loss_batch_by_step = loss_batch_by_step * dec_padding_mask.reshape( -1) batch_size = vocab_scores[0].size(0) rl_loss = torch.sum(mask_loss_batch_by_step.reshape( batch_size, -1), dim=1) r1 = reward(target_tensor, greedy_seq_tensor) r2 = reward(target_tensor, sample_seq_tensor) reward_diff = r1 - r2 final_rl_loss = reward_diff * rl_loss loss = torch.mean(final_rl_loss) print( 'r1: %.3f, r2: %.3f, reward_diff: %.3f, final rl loss: %.3f, loss batch mean: %.3f' % (torch.max(r1).item(), torch.max(r2).item(), torch.max(reward_diff).item(), torch.max(final_rl_loss).item(), loss.item())) # for maximum likelihood if hps.maxium_likelihood: if hps.pointer_gen: loss_per_step = [] for dec_step, dist in enumerate(final_dists): # dist = [batch_size, extended_vsize] targets = target_tensor[:, dec_step] gold_probs = torch.gather( dist, 1, targets.unsqueeze(1)).squeeze() losses = -torch.log(gold_probs) loss_per_step.append(losses) # a list of [batch_size,] loss = mask_and_avg(loss_per_step, dec_padding_mask) else: # a list of dec_max_len (vocab_scores) loss_batch_by_step = F.cross_entropy( torch.stack(vocab_scores, dim=1).reshape(-1, vocab.size()), target_tensor.reshape(-1), size_average=False, reduce=False) # loss [batch_size*dec_max_len,] loss = torch.sum(loss_batch_by_step * dec_padding_mask.reshape(-1)) / torch.sum( dec_padding_mask) epoch_loss_track.append(loss.item()) global_step += 1 loss.backward() nn.utils.clip_grad_norm_(net.parameters(), max_norm=hps.norm_limit) optimizer.step() #print('time one step:', time.time()-start) if (global_step == 1) or (global_step % hps.print_every == 0): print('Step {:>5}: ave loss: {:>10.4f}, speed: {:.1f} case/s'. format(global_step, sum(epoch_loss_track) / len(epoch_loss_track), hps.batch_size / (time.time() - start)))
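The self-critic branch above follows self-critical sequence training (Rennie et al., 2017): the greedy decode acts as a baseline for the sampled sequence, so only samples that beat the baseline are reinforced. A minimal sketch of the loss with the textbook sign convention (the per-sequence NLL corresponds to rl_loss above; rewards are treated as constants):

import torch

def scst_loss(sample_nll, sample_reward, greedy_reward):
    """Self-critical loss: weight each sequence's NLL by its advantage over the
    greedy baseline, so better-than-greedy samples become more likely.
    All arguments are (batch,) tensors."""
    advantage = (sample_reward - greedy_reward).detach()  # no gradient through rewards
    return (advantage * sample_nll).mean()

Under this convention the advantage is r_sample - r_greedy, whereas the loop above weights by reward_diff = r1 - r2 (greedy minus sample); check the convention of the reward function (here get_batch_bleu, higher is better) when comparing the two.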
def main(): vocab = Vocab(hps.word_count_path, hps.glove_path, hps.embedding_dim) net = PointerNet(hps, vocab.emb_mat) net = net.cuda() data_batcher = batcher(hps.data_path, vocab, hps, hps.single_pass) model_parameters = filter(lambda p: p.requires_grad, net.parameters()) optimizer = optim.Adam(model_parameters) loss_track = [] global_step = 0 while True: start = time.time() batch = next(data_batcher) #batch = pickle.load(open('one_batch.pkl', 'rb')) paragraph_tensor = torch.tensor(batch.enc_batch, dtype=torch.int64, requires_grad=False).cuda() question_tensor = torch.tensor(batch.dec_batch, dtype=torch.int64, requires_grad=False).cuda() answer_position_tensor = torch.tensor(batch.ans_indices, dtype=torch.int64, requires_grad=False).cuda() target_tensor = torch.tensor(batch.target_batch, dtype=torch.int64, requires_grad=False).cuda() paragraph_batch_extend_vocab = None max_para_oovs = None if hps.pointer_gen: paragraph_batch_extend_vocab = torch.tensor(batch.enc_batch_extend_vocab, dtype=torch.int64, requires_grad=False).cuda() max_para_oovs = batch.max_para_oovs vocab_scores, vocab_dists, attn_dists, final_dists = net(paragraph_tensor, question_tensor, answer_position_tensor, paragraph_batch_extend_vocab, max_para_oovs) optimizer.zero_grad() dec_padding_mask = torch.ne(target_tensor, 0).float().cuda() if hps.pointer_gen: loss_per_step = [] for dec_step, dist in enumerate(final_dists): # dist = [batch_size, extended_vsize] targets = target_tensor[:,dec_step] gold_probs = torch.gather(dist, 1, targets.unsqueeze(1)).squeeze() losses = -torch.log(gold_probs) loss_per_step.append(losses) # a list of [batch_size,] loss = mask_and_avg(loss_per_step, dec_padding_mask) else: # a list of dec_max_len (vocab_scores) loss_batch_by_step = F.cross_entropy(torch.cat(vocab_scores, dim=1).reshape(-1, vocab.size()), target_tensor.reshape(-1), size_average=False, reduce=False) # loss [batch_size*dec_max_len,] loss = torch.sum(loss_batch_by_step * dec_padding_mask.reshape(-1))/torch.sum(dec_padding_mask) loss_track.append(loss.item()) global_step += 1 loss.backward() optimizer.step() if global_step % hps.print_every == 0: print('Step {:>10}: ave loss: {:>10.4f}, speed: {:.4f}/case'.format(global_step, sum(loss_track)/len(loss_track), (time.time()-start)/hps.batch_size)) loss_track = []
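Both training loops call mask_and_avg for the pointer-generator loss but never define it. A sketch consistent with the call sites (including the batch_average/step_average keywords used in the self-critic branch; the repo's own helper may differ in details):

import torch

def mask_and_avg(values, padding_mask, batch_average=True, step_average=True):
    """Average a list of per-step losses over real (non-pad) positions.
    values: list of (batch,) tensors, one per decoder step.
    padding_mask: (batch, max_dec_steps), 1.0 for real tokens, 0.0 for padding."""
    dec_lens = torch.sum(padding_mask, dim=1)            # (batch,) true decode lengths
    masked = [v * padding_mask[:, t] for t, v in enumerate(values)]
    per_example = torch.stack(masked, dim=1).sum(dim=1)  # (batch,) summed loss
    if step_average:
        per_example = per_example / dec_lens             # per-token loss
    if batch_average:
        return per_example.mean()                        # scalar, as in the ML branch
    return per_example                                   # (batch,), as in the self-critic branch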