# Example 1
    def _run_batch_basic(self, dialogue_batch, sess, summary_map, test=False):
        '''
        Run the truncated RNN over a sequence of batches, threading the
        decoder's final state into the next batch's initial encoder state.
        '''
        matched_items = dialogue_batch['matched_items']
        encoder_init_state = None
        for batch in dialogue_batch['batch_seq']:
            feed_dict = self._get_feed_dict(batch, encoder_init_state, matched_items=matched_items)
            # Fetches shared by the train and test paths.
            fetches = [
                self.model.decoder.output_dict['logits'],
                self.model.decoder.output_dict['final_state'],
                self.model.loss,
                self.model.seq_loss,
            ]
            if test:
                logits, final_state, loss, seq_loss, total_loss = sess.run(
                    fetches + [self.model.total_loss], feed_dict=feed_dict)
            else:
                # Also run the training op and fetch the gradient norm.
                _, logits, final_state, loss, seq_loss, gn = sess.run(
                    [self.train_op] + fetches + [self.grad_norm],
                    feed_dict=feed_dict)
            encoder_init_state = final_state

            if self.verbose:
                self._print_batch(batch, np.argmax(logits, axis=2), seq_loss)

            if test:
                # total_loss is a (summed loss, token count) pair.
                logstats.update_summary_map(summary_map, {'total_loss': total_loss[0], 'num_tokens': total_loss[1]})
            else:
                logstats.update_summary_map(summary_map, {'loss': loss})
                logstats.update_summary_map(summary_map, {'grad_norm': gn})
# Example 2
    def _run_batch_graph(self, dialogue_batch, sess, summary_map, test=False):
        '''
        Run the truncated RNN over a sequence of batches with knowledge
        graphs, threading both the RNN state and the graph utterance
        embeddings from one batch into the next.
        '''
        graphs = dialogue_batch['graph']
        matched_items = dialogue_batch['matched_items']
        encoder_init_state = None
        utterances = None
        for batch in dialogue_batch['batch_seq']:
            graph_data = graphs.get_batch_data(
                batch['encoder_tokens'], batch['decoder_tokens'],
                batch['encoder_entities'], batch['decoder_entities'],
                utterances, self.vocab)
            feed_dict = self._get_feed_dict(
                batch, encoder_init_state, graph_data, graphs,
                self.data.copy, graphs.get_zero_checklists(1),
                graph_data['encoder_nodes'], graph_data['decoder_nodes'],
                matched_items)
            # Fetches shared by the train and test paths.
            fetches = [
                self.model.decoder.output_dict['logits'],
                self.model.decoder.output_dict['final_state'],
                self.model.decoder.output_dict['utterances'],
                self.model.loss,
                self.model.seq_loss,
            ]
            if test:
                logits, final_state, utterances, loss, seq_loss, total_loss = sess.run(
                    fetches + [self.model.total_loss], feed_dict=feed_dict)
            else:
                # Also run the training op and fetch the gradient norm.
                _, logits, final_state, utterances, loss, seq_loss, gn = sess.run(
                    [self.train_op] + fetches + [self.grad_norm],
                    feed_dict=feed_dict)
            # final_state = (rnn_state, attn, context); only the rnn state
            # seeds the next batch's encoder.
            encoder_init_state = final_state[0]

            if self.verbose:
                preds = np.argmax(logits, axis=2)
                if self.data.copy:
                    # Map copy-mechanism indices back into the vocab range.
                    preds = graphs.copy_preds(preds, self.data.mappings['vocab'].size)
                self._print_batch(batch, preds, seq_loss)

            if test:
                # total_loss is a (summed loss, token count) pair.
                logstats.update_summary_map(summary_map, {
                    'total_loss': total_loss[0],
                    'num_tokens': total_loss[1]
                })
            else:
                logstats.update_summary_map(summary_map, {'loss': loss})
                logstats.update_summary_map(summary_map, {'grad_norm': gn})
# Example 3
    def learn(self, args, config, stats_file, ckpt=None, split='train'):
        '''
        Train the model: build the optimizer and train op, then loop over
        epochs, checkpointing after each one and keeping the best model
        (lowest dev loss) in a separate '-best' directory. Stops early
        after 5 epochs without sufficient dev-loss improvement, once
        args.min_epochs have passed.

        args: training options (optimizer, learning_rate, grad_clip,
            min_epochs, max_epochs, print_every, checkpoint, init_from).
        config: session configuration passed to tf.Session.
        stats_file: path handed to logstats.init for logging.
        ckpt: checkpoint state restored when args.init_from is set.
        split: name of the data split used for training batches.
        '''
        logstats.init(stats_file)
        assert args.min_epochs <= args.max_epochs

        assert args.optimizer in optim.keys()
        optimizer = optim[args.optimizer](args.learning_rate)

        # Gradient
        grads_and_vars = optimizer.compute_gradients(self.model.loss)
        if args.grad_clip > 0:
            # Element-wise value clipping to [-grad_clip, grad_clip].
            min_grad, max_grad = -1. * args.grad_clip, args.grad_clip
            clipped_grads_and_vars = [
                (tf.clip_by_value(grad, min_grad, max_grad), var)
                for grad, var in grads_and_vars
            ]
        else:
            clipped_grads_and_vars = grads_and_vars
        # TODO: clip has problem with indexedslices, don't use
        #self.clipped_grads = [grad for grad, var in clipped_grads_and_vars]
        #self.grads = [grad for grad, var in grads_and_vars]
        # Global norms of the raw and clipped gradients, for monitoring.
        self.grad_norm = tf.global_norm([grad for grad, var in grads_and_vars])
        self.clipped_grad_norm = tf.global_norm(
            [grad for grad, var in clipped_grads_and_vars])

        # Optimize
        self.train_op = optimizer.apply_gradients(clipped_grads_and_vars)

        # Training loop
        train_data = self.data.generator(split, self.batch_size)
        # The generator's first yield is the number of batches per epoch.
        num_per_epoch = train_data.next()
        step = 0
        saver = tf.train.Saver()
        save_path = os.path.join(args.checkpoint, 'tf_model.ckpt')
        # Separate saver (max_to_keep=1) so the best model is never rotated
        # away by the regular per-epoch checkpoints.
        best_saver = tf.train.Saver(max_to_keep=1)
        best_checkpoint = args.checkpoint + '-best'
        if not os.path.isdir(best_checkpoint):
            os.mkdir(best_checkpoint)
        best_save_path = os.path.join(best_checkpoint, 'tf_model.ckpt')
        best_loss = float('inf')
        # Number of iterations without any improvement
        num_epoch_no_impr = 0

        # Testing
        with tf.Session(config=config) as sess:
            tf.initialize_all_variables().run()
            if args.init_from:
                saver.restore(sess, ckpt.model_checkpoint_path)
            summary_map = {}
            #for epoch in xrange(args.max_epochs):
            epoch = 1
            while True:
                print '================== Epoch %d ==================' % (
                    epoch)
                for i in xrange(num_per_epoch):
                    start_time = time.time()
                    # self._run_batch dispatches to the model-specific batch
                    # runner; accumulates loss/grad_norm into summary_map.
                    self._run_batch(train_data.next(),
                                    sess,
                                    summary_map,
                                    test=False)
                    end_time = time.time()
                    logstats.update_summary_map(summary_map, \
                            {'time(s)/batch': end_time - start_time, \
                             'memory(MB)': memory()})
                    step += 1
                    if step % args.print_every == 0 or step % num_per_epoch == 0:
                        print '{}/{} (epoch {}) {}'.format(
                            i + 1, num_per_epoch, epoch,
                            logstats.summary_map_to_str(summary_map))
                        summary_map = {}  # Reset
                step = 0

                # Save model after each epoch
                print 'Save model checkpoint to', save_path
                saver.save(sess, save_path, global_step=epoch)

                # Evaluate on dev
                # NOTE(review): this loop variable shadows the 'split'
                # parameter; harmless because train_data was already built
                # above, but renaming it would be clearer.
                for split, test_data, num_batches in self.evaluator.dataset():
                    print '================== Eval %s ==================' % split
                    print '================== Perplexity =================='
                    start_time = time.time()
                    loss = self.test_loss(sess, test_data, num_batches)
                    print 'loss=%.4f time(s)=%.4f' % (loss,
                                                      time.time() - start_time)
                    print '================== Sampling =================='
                    start_time = time.time()
                    bleu, (ent_prec, ent_recall,
                           ent_f1) = self.evaluator.test_bleu(
                               sess, test_data, num_batches)
                    print 'bleu=%.4f/%.4f/%.4f entity_f1=%.4f/%.4f/%.4f time(s)=%.4f' % (
                        bleu[0], bleu[1], bleu[2], ent_prec, ent_recall,
                        ent_f1, time.time() - start_time)

                    # Start to record no improvement epochs
                    # Requires at least 0.5% relative improvement over the
                    # best dev loss to reset the patience counter.
                    if split == 'dev' and epoch > args.min_epochs:
                        if loss < best_loss * 0.995:
                            num_epoch_no_impr = 0
                        else:
                            num_epoch_no_impr += 1

                    if split == 'dev' and loss < best_loss:
                        print 'New best model'
                        best_loss = loss
                        best_saver.save(sess, best_save_path)
                        logstats.add(
                            'best_model', {
                                'bleu-4': bleu[0],
                                'bleu-3': bleu[1],
                                'bleu-2': bleu[2],
                                'entity_precision': ent_prec,
                                'entity_recall': ent_recall,
                                'entity_f1': ent_f1,
                                'loss': loss,
                                'epoch': epoch
                            })

                # Early stop when no improvement
                # NOTE(review): the break tests 'epoch > args.max_epochs',
                # so max_epochs + 1 epochs actually run — confirm intended.
                if (epoch > args.min_epochs
                        and num_epoch_no_impr >= 5) or epoch > args.max_epochs:
                    break
                epoch += 1