def _run_batch_basic(self, dialogue_batch, sess, summary_map, test=False):
    '''Feed one dialogue (a sequence of truncated-RNN chunks) through the model.

    The RNN state is threaded across chunks: each chunk's final decoder state
    seeds the encoder initial state of the next chunk. In train mode the
    optimizer step runs and loss/grad-norm are accumulated into summary_map;
    in test mode only the token-normalized total loss is accumulated.
    '''
    encoder_init_state = None
    matched_items = dialogue_batch['matched_items']
    for batch in dialogue_batch['batch_seq']:
        feed_dict = self._get_feed_dict(batch, encoder_init_state, matched_items=matched_items)
        # Fetches common to both modes, in unpacking order.
        fetches = [self.model.decoder.output_dict['logits'],
                   self.model.decoder.output_dict['final_state'],
                   self.model.loss,
                   self.model.seq_loss]
        if test:
            logits, final_state, loss, seq_loss, total_loss = sess.run(
                fetches + [self.model.total_loss], feed_dict=feed_dict)
        else:
            # Prepend train_op so the gradient step actually executes.
            _, logits, final_state, loss, seq_loss, gn = sess.run(
                [self.train_op] + fetches + [self.grad_norm], feed_dict=feed_dict)
        # Carry the RNN state into the next truncated chunk.
        encoder_init_state = final_state
        if self.verbose:
            preds = np.argmax(logits, axis=2)
            self._print_batch(batch, preds, seq_loss)
        if test:
            # total_loss is a (sum, count) pair — presumably (summed loss, num tokens); verify against model.
            logstats.update_summary_map(summary_map, {'total_loss': total_loss[0], 'num_tokens': total_loss[1]})
        else:
            logstats.update_summary_map(summary_map, {'loss': loss})
            logstats.update_summary_map(summary_map, {'grad_norm': gn})
def _run_batch_graph(self, dialogue_batch, sess, summary_map, test=False):
    '''Feed one dialogue through the model with knowledge-graph features.

    Like _run_batch_basic, but each chunk also gets per-batch graph data and
    checklists, and the decoder's utterance representations are threaded from
    chunk to chunk along with the RNN state.
    '''
    encoder_init_state = None
    utterances = None
    graphs = dialogue_batch['graph']
    matched_items = dialogue_batch['matched_items']
    for batch in dialogue_batch['batch_seq']:
        graph_data = graphs.get_batch_data(
            batch['encoder_tokens'], batch['decoder_tokens'],
            batch['encoder_entities'], batch['decoder_entities'],
            utterances, self.vocab)
        init_checklists = graphs.get_zero_checklists(1)
        feed_dict = self._get_feed_dict(
            batch, encoder_init_state, graph_data, graphs, self.data.copy,
            init_checklists, graph_data['encoder_nodes'],
            graph_data['decoder_nodes'], matched_items)
        output_dict = self.model.decoder.output_dict
        # Fetches common to both modes, in unpacking order.
        fetches = [output_dict['logits'],
                   output_dict['final_state'],
                   output_dict['utterances'],
                   self.model.loss,
                   self.model.seq_loss]
        if test:
            logits, final_state, utterances, loss, seq_loss, total_loss = sess.run(
                fetches + [self.model.total_loss], feed_dict=feed_dict)
        else:
            # Prepend train_op so the gradient step actually executes.
            _, logits, final_state, utterances, loss, seq_loss, gn = sess.run(
                [self.train_op] + fetches + [self.grad_norm], feed_dict=feed_dict)
        # final_state = (rnn_state, attn, context); only the rnn_state seeds
        # the next truncated chunk.
        encoder_init_state = final_state[0]
        if self.verbose:
            preds = np.argmax(logits, axis=2)
            if self.data.copy:
                # Map copy-mechanism indices back into the vocabulary space.
                preds = graphs.copy_preds(preds, self.data.mappings['vocab'].size)
            self._print_batch(batch, preds, seq_loss)
        if test:
            # total_loss is a (sum, count) pair — presumably (summed loss, num tokens); verify against model.
            logstats.update_summary_map(summary_map, {'total_loss': total_loss[0], 'num_tokens': total_loss[1]})
        else:
            logstats.update_summary_map(summary_map, {'loss': loss})
            logstats.update_summary_map(summary_map, {'grad_norm': gn})
def learn(self, args, config, stats_file, ckpt=None, split='train'): logstats.init(stats_file) assert args.min_epochs <= args.max_epochs assert args.optimizer in optim.keys() optimizer = optim[args.optimizer](args.learning_rate) # Gradient grads_and_vars = optimizer.compute_gradients(self.model.loss) if args.grad_clip > 0: min_grad, max_grad = -1. * args.grad_clip, args.grad_clip clipped_grads_and_vars = [ (tf.clip_by_value(grad, min_grad, max_grad), var) for grad, var in grads_and_vars ] else: clipped_grads_and_vars = grads_and_vars # TODO: clip has problem with indexedslices, don't use #self.clipped_grads = [grad for grad, var in clipped_grads_and_vars] #self.grads = [grad for grad, var in grads_and_vars] self.grad_norm = tf.global_norm([grad for grad, var in grads_and_vars]) self.clipped_grad_norm = tf.global_norm( [grad for grad, var in clipped_grads_and_vars]) # Optimize self.train_op = optimizer.apply_gradients(clipped_grads_and_vars) # Training loop train_data = self.data.generator(split, self.batch_size) num_per_epoch = train_data.next() step = 0 saver = tf.train.Saver() save_path = os.path.join(args.checkpoint, 'tf_model.ckpt') best_saver = tf.train.Saver(max_to_keep=1) best_checkpoint = args.checkpoint + '-best' if not os.path.isdir(best_checkpoint): os.mkdir(best_checkpoint) best_save_path = os.path.join(best_checkpoint, 'tf_model.ckpt') best_loss = float('inf') # Number of iterations without any improvement num_epoch_no_impr = 0 # Testing with tf.Session(config=config) as sess: tf.initialize_all_variables().run() if args.init_from: saver.restore(sess, ckpt.model_checkpoint_path) summary_map = {} #for epoch in xrange(args.max_epochs): epoch = 1 while True: print '================== Epoch %d ==================' % ( epoch) for i in xrange(num_per_epoch): start_time = time.time() self._run_batch(train_data.next(), sess, summary_map, test=False) end_time = time.time() logstats.update_summary_map(summary_map, \ {'time(s)/batch': end_time - start_time, \ 
'memory(MB)': memory()}) step += 1 if step % args.print_every == 0 or step % num_per_epoch == 0: print '{}/{} (epoch {}) {}'.format( i + 1, num_per_epoch, epoch, logstats.summary_map_to_str(summary_map)) summary_map = {} # Reset step = 0 # Save model after each epoch print 'Save model checkpoint to', save_path saver.save(sess, save_path, global_step=epoch) # Evaluate on dev for split, test_data, num_batches in self.evaluator.dataset(): print '================== Eval %s ==================' % split print '================== Perplexity ==================' start_time = time.time() loss = self.test_loss(sess, test_data, num_batches) print 'loss=%.4f time(s)=%.4f' % (loss, time.time() - start_time) print '================== Sampling ==================' start_time = time.time() bleu, (ent_prec, ent_recall, ent_f1) = self.evaluator.test_bleu( sess, test_data, num_batches) print 'bleu=%.4f/%.4f/%.4f entity_f1=%.4f/%.4f/%.4f time(s)=%.4f' % ( bleu[0], bleu[1], bleu[2], ent_prec, ent_recall, ent_f1, time.time() - start_time) # Start to record no improvement epochs if split == 'dev' and epoch > args.min_epochs: if loss < best_loss * 0.995: num_epoch_no_impr = 0 else: num_epoch_no_impr += 1 if split == 'dev' and loss < best_loss: print 'New best model' best_loss = loss best_saver.save(sess, best_save_path) logstats.add( 'best_model', { 'bleu-4': bleu[0], 'bleu-3': bleu[1], 'bleu-2': bleu[2], 'entity_precision': ent_prec, 'entity_recall': ent_recall, 'entity_f1': ent_f1, 'loss': loss, 'epoch': epoch }) # Early stop when no improvement if (epoch > args.min_epochs and num_epoch_no_impr >= 5) or epoch > args.max_epochs: break epoch += 1