Example #1
    def run_internal_eval(self,
                          eval_model, eval_sess, model_dir, summary_writer, use_test_set=True):
        """Compute internal evaluation (perplexity) for both dev / test."""
        with eval_model.graph.as_default():
            loaded_eval_model, global_step = model_helper.create_or_load_model(
                eval_model.model, model_dir, eval_sess, "eval")

        dev_file = self.config.dev_data

        dev_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: dev_file
        }

        dev_ppl = self._internal_eval(loaded_eval_model, global_step, eval_sess,
                                      eval_model.iterator, dev_eval_iterator_feed_dict,
                                      summary_writer, "dev")
        log.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)

        if dev_ppl < self.config.best_dev_ppl:
            loaded_eval_model.saver.save(eval_sess,
                                         os.path.join(self.config.best_dev_ppl_dir, 'taware.ckpt'),
                                         global_step=global_step)

        test_ppl = None
        if use_test_set:
            test_file = self.config.test_data

            test_eval_iterator_feed_dict = {
                eval_model.eval_file_placeholder: test_file
            }
            test_ppl = self._internal_eval(loaded_eval_model, global_step, eval_sess,
                                           eval_model.iterator, test_eval_iterator_feed_dict,
                                           summary_writer, "test")
        return dev_ppl, test_ppl
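
The `_internal_eval` helper called above is not reproduced on this page, but Example #10 below inlines essentially the same steps (re-initialize the evaluation iterator, then compute perplexity). A minimal sketch under that assumption, reusing the `model_helper.compute_perplexity` call seen in Example #10:

    def _internal_eval(self, model, global_step, sess, iterator,
                       iterator_feed_dict, summary_writer, label):
        """Sketch only: mirrors the inline logic of Example #10.

        global_step and summary_writer are unused here because the caller
        writes the summary itself (see the log.add_summary call above).
        """
        # Point the shared eval iterator at the requested data file.
        sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
        # Same helper that Example #10 calls directly.
        return model_helper.compute_perplexity(model, sess, label)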
Example #2
    def __init__(self, hparams):
        self.hparams = hparams
        # print("====test__init__==\n")
        # Data locations
        self.out_dir = hparams.out_dir
        # print("our_dir:", self.out_dir)
        self.model_dir = os.path.join(self.out_dir, 'ckpts')
        # print("model_dir:", self.model_dir)
        # Create models
        attention_option = hparams.attention_option

        if attention_option:
            model_creator = AttentionModel
        else:
            model_creator = BasicModel

        self.infer_model = model_helper.create_infer_model(
            hparams=hparams, model_creator=model_creator)

        # Sessions
        config_proto = utils.get_config_proto()
        self.infer_sess = tf.Session(config=config_proto,
                                     graph=self.infer_model.graph)

        # EOS
        self.tgt_eos = Vocabulary.EOS.encode("utf-8")
        # Load infer model
        with self.infer_model.graph.as_default():
            self.loaded_infer_model, self.global_step = model_helper.create_or_load_model(
                self.infer_model.model, self.model_dir, self.infer_sess,
                "infer")
Example #3
def run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                            summary_writer, save_on_best_dev):
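    """Run external evaluation (e.g. BLEU) on the dev set and, if a test prefix is configured, on the test set."""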
    with infer_model.graph.as_default():
        # Load the model from checkpoint. It automatically loads the latest checkpoint
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            model=infer_model.model,
            model_dir=model_dir,
            session=infer_sess,
            name="infer"
        )
        # Fill the feed_dict for the evaluation
        dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
        dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)

        inference_dev_data = inference.load_data(dev_src_file)
        dev_infer_iterator_feed_dict = {
            infer_model.src_placeholder: inference_dev_data,
            infer_model.batch_size_placeholder: hparams.infer_batch_size
        }

        dev_scores = _external_eval(
            model=loaded_infer_model,
            global_step=global_step,
            sess=infer_sess,
            hparams=hparams,
            iterator=infer_model.iterator,
            iterator_feed_dict=dev_infer_iterator_feed_dict,
            tgt_file=dev_tgt_file,
            label="dev",
            summary_writer=summary_writer,
            save_on_best_dev=save_on_best_dev
        )

        test_scores = None
        if hparams.test_prefix:
            # Create the test data
            test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
            test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
            inference_test_data = inference.load_data(test_src_file)
            test_infer_iterator_feed_dict = {
                infer_model.src_placeholder: inference_test_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            }
            # Run evaluation on the test dataset
            test_scores = _external_eval(
                model=loaded_infer_model,
                global_step=global_step,
                sess=infer_sess,
                hparams=hparams,
                iterator=infer_model.iterator,
                iterator_feed_dict=test_infer_iterator_feed_dict,
                tgt_file=test_tgt_file,
                label="test",
                summary_writer=summary_writer,
                save_on_best_dev=False  # We do not use the test set at all in training as that means overfitting
            )

    return dev_scores, test_scores, global_step
Example #4
    def run_sample_decode(self, infer_model, infer_sess, model_dir, summary_writer, eval_data):
        """Sample decode a random sentence from src_data."""
        with infer_model.graph.as_default():
            loaded_infer_model, global_step = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, "infer")

        self.__sample_decode(loaded_infer_model, global_step, infer_sess,
                            infer_model.iterator, eval_data,
                            infer_model.src_placeholder,
                            infer_model.batch_size_placeholder, summary_writer)
Example #5
    def run_external_eval(self,
                          infer_model,
                          infer_sess,
                          model_dir,
                          summary_writer,
                          save_best_dev=True,
                          use_test_set=True):
        """Compute external evaluation (bleu, rouge, etc.) for both dev / test."""
        with infer_model.graph.as_default():
            loaded_infer_model, global_step = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, "infer")

        dev_infer_iterator_feed_dict = {
            infer_model.src_placeholder: self._load_data(self.config.dev_data),
            infer_model.batch_size_placeholder: self.config.infer_batch_size,
        }

        dev_scores = self._external_eval(loaded_infer_model,
                                         global_step,
                                         infer_sess,
                                         infer_model.iterator,
                                         dev_infer_iterator_feed_dict,
                                         self.config.dev_data,
                                         "dev",
                                         summary_writer,
                                         save_on_best=save_best_dev)

        test_scores = None
        if use_test_set:
            test_file = self.config.test_data
            test_infer_iterator_feed_dict = {
                infer_model.src_placeholder: self._load_data(test_file),
                infer_model.batch_size_placeholder:
                self.config.infer_batch_size,
            }

            test_scores = self._external_eval(loaded_infer_model,
                                              global_step,
                                              infer_sess,
                                              infer_model.iterator,
                                              test_infer_iterator_feed_dict,
                                              test_file,
                                              "test",
                                              summary_writer,
                                              save_on_best=False)
        return dev_scores, test_scores, global_step
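
The `_load_data` helper used above is not shown in these examples. Given that its result is fed straight into `src_placeholder` as a list of source sentences, a minimal sketch might look as follows; the `include_target=True` variant used in Example #14 is omitted because its exact format is not visible here:

    def _load_data(self, inference_input_file):
        """Sketch only: read one source sentence per line for inference."""
        with open(inference_input_file, encoding='utf-8') as f:
            return [line.strip() for line in f]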
Example #6
    def chat(self):
        """Accept a input str and get response by trained model."""
        model_dir = self.model_dir
        infer_model = self.infer_model
        infer_sess = self.infer_sess
        beam_width = self.hparams.beam_width

        # Load infer model
        with infer_model.graph.as_default():
            loaded_infer_model, global_step = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, "infer")

        # Warm up jieba
        jieba.lcut("jieba")

        print("请输入字母'q' 或者 '退出'标示结束!\n\n")

        while True:
            input_str = input('Me > ')
            if not input_str.strip():
                continue
            if input_str == "q" or input_str == "退出":
                break

            input_seg = jieba.lcut(input_str)
            start_time = time.time()

            iterator_feed_dict = {
                infer_model.src_data_placeholder: input_seg,
                infer_model.batch_size_placeholder: 1
            }
            infer_sess.run(self.infer_model.iterator.initializer,
                           feed_dict=iterator_feed_dict)

            sample_words = loaded_infer_model.decode(infer_sess)

            if beam_width > 0:
                # Get a random answer.
                beam_id = random.randint(0, beam_width - 1)
                sample_words = sample_words[beam_id]

            response = self._get_response(sample_words)
            response = "".join(re.split(" ", response))

            print("AI > %s (%.4fs)" % (response, time.time() - start_time))
Example #7
    def sample_decode(self, num_sentences=1):
        """Sample decode num_sentences random sentence from src_data."""
        model_dir = self.model_dir
        infer_model = self.infer_model
        infer_sess = self.infer_sess
        train_src_file = self.train_src_file
        train_tgt_file = self.train_tgt_file
        beam_width = self.hparams.beam_width

        start_time = time.time()

        # Load infer model
        with infer_model.graph.as_default():
            loaded_infer_model, global_step = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, "infer")

        with open(train_src_file, encoding='utf-8') as src_f, \
                open(train_tgt_file, encoding='utf-8') as tgt_f:
            src_data = src_f.readlines()
            tgt_data = tgt_f.readlines()

        for _ in range(num_sentences):
            decode_id = random.randint(0, len(src_data) - 1)
            print("# Decoding sentence %d" % decode_id)

            iterator_feed_dict = {
                infer_model.src_data_placeholder: [src_data[decode_id]],
                infer_model.batch_size_placeholder: 1
            }
            infer_sess.run(
                self.infer_model.iterator.initializer,
                feed_dict=iterator_feed_dict)

            sample_words = loaded_infer_model.decode(infer_sess)

            if beam_width > 0:
                # get the top translation.
                sample_words = sample_words[0]

            response = self._get_response(sample_words)

            print("  src: %s" % src_data[decode_id], end='')
            print("  ref: %s" % tgt_data[decode_id], end='')
            print("  bot: %s" % response)
            print("  tim: %.4fs" % (time.time() - start_time))
Example #8
    def _get_eval_perplexity(self, name):
        model_dir = self.model_dir
        eval_model = self.eval_model
        eval_sess = self.eval_sess

        with eval_model.graph.as_default():
            loaded_eval_model, global_step = model_helper.create_or_load_model(
                eval_model.model, model_dir, eval_sess, 'eval')

        dev_eval_iterator_feed_dict = {
            eval_model.src_file_placeholder: self.dev_src_file,
            eval_model.tgt_file_placeholder: self.dev_tgt_file
        }

        dev_ppl = eval_utils.internal_eval(
            loaded_eval_model, global_step, eval_sess, eval_model.iterator,
            dev_eval_iterator_feed_dict, name)

        return dev_ppl
Example #9
def run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer):
    """Compute internal evaluation (perplexity) for both dev / test."""
    with eval_model.graph.as_default():
        # Load the latest checkpoint from file
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            model=eval_model.model,
            model_dir=model_dir,
            session=eval_sess,
            name="eval"
        )
        # Fill the feed_dict for the evaluation
        dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
        dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
        dev_eval_iterator_feed_dict = {
            eval_model.src_file_placeholder: dev_src_file,
            eval_model.tgt_file_placeholder: dev_tgt_file
        }
        # Run evaluation on the development (validation) dataset
        dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                                 iterator=eval_model.iterator,
                                 iterator_feed_dict=dev_eval_iterator_feed_dict,
                                 summary_writer=summary_writer,
                                 label='dev')

        test_ppl = None
        if hparams.test_prefix:
            # Create the test data
            test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
            test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
            test_eval_iterator_feed_dict = {
                eval_model.src_file_placeholder: test_src_file,
                eval_model.tgt_file_placeholder: test_tgt_file
            }
            # Run evaluation on the test dataset
            test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                                      iterator=eval_model.iterator,
                                      iterator_feed_dict=test_eval_iterator_feed_dict,
                                      summary_writer=summary_writer,
                                      label='test')

    return dev_ppl, test_ppl
Example #10
    def run_internal_eval(self,
                          eval_model,
                          eval_sess,
                          model_dir,
                          summary_writer,
                          use_test_set=True):
        """Compute internal evaluation (perplexity) for both dev / test."""
        with eval_model.graph.as_default():
            loaded_eval_model, global_step = model_helper.create_or_load_model(
                eval_model.model, model_dir, eval_sess, "eval")

        dev_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: self.config.dev_data
        }

        eval_sess.run(eval_model.iterator.initializer,
                      feed_dict=dev_eval_iterator_feed_dict)
        dev_ppl = model_helper.compute_perplexity(loaded_eval_model, eval_sess,
                                                  "dev")
        log.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)

        if dev_ppl < self.config.best_dev_ppl:
            loaded_eval_model.saver.save(
                eval_sess,
                os.path.join(self.config.best_dev_ppl_dir,
                             '{}.ckpt'.format(self._get_checkpoint_name())),
                global_step=global_step)

        test_ppl = None
        if use_test_set:
            test_eval_iterator_feed_dict = {
                eval_model.eval_file_placeholder: self.config.test_data
            }
            eval_sess.run(eval_model.iterator.initializer,
                          feed_dict=test_eval_iterator_feed_dict)
            test_ppl = model_helper.compute_perplexity(loaded_eval_model,
                                                       eval_sess, "test")
            log.add_summary(summary_writer, global_step, "test_ppl", test_ppl)

        return dev_ppl, test_ppl
Example #11
def run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                      src_data, tgt_data):
    """
    Sample decode a random sentence from the source data. Used to print the tangible progress of the model.
    :param infer_model: The model used to produce the response.
    :param model_dir: directory which contains the trained model
    :param summary_writer: An instance of a tensorflow Summary writer
    :return:
    """
    with infer_model.graph.as_default():
        # Load the model from checkpoint. It automatically loads the latest checkpoint
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            model=infer_model.model,
            model_dir=model_dir,
            session=infer_sess,
            name="infer"
        )
        _sample_decode(model=loaded_infer_model, global_step=global_step, sess=infer_sess, hparams=hparams,
                       iterator=infer_model.iterator, src_data=src_data, tgt_data=tgt_data,
                       iterator_src_placeholder=infer_model.src_placeholder,
                       iterator_batch_size_placeholder=infer_model.batch_size_placeholder,
                       summary_writer=summary_writer)
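
`_sample_decode` itself is not included on this page; Example #7 inlines very similar logic in a class-based variant. A rough module-level sketch under the same assumptions (converting `sample_words` back to text is left out because that helper is not shown):

import random

def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Sketch only: decode one random sentence, as Example #7 does inline."""
    decode_id = random.randint(0, len(src_data) - 1)
    print("# Decoding sentence %d" % decode_id)
    sess.run(iterator.initializer,
             feed_dict={iterator_src_placeholder: [src_data[decode_id]],
                        iterator_batch_size_placeholder: 1})
    sample_words = model.decode(sess)
    if hparams.beam_width > 0:
        sample_words = sample_words[0]  # keep only the top beam
    print("  src: %s" % src_data[decode_id])
    print("  ref: %s" % tgt_data[decode_id])
    return sample_words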
Example #12
    def run_full_eval(self, infer_model, eval_model, infer_sess, eval_sess,
                      model_dir, label, summary_writer):
        dev_ppl, test_ppl = self.run_internal_eval(eval_model,
                                                   eval_sess,
                                                   model_dir,
                                                   summary_writer,
                                                   use_test_set=True)

        with infer_model.graph.as_default():
            loaded_infer_model, _ = model_helper.create_or_load_model(
                infer_model.model, self.config.model_dir, infer_sess, "infer")
            infer_feed_dict = {
                infer_model.src_placeholder:
                self._load_data(self.config.test_data),
                infer_model.batch_size_placeholder:
                self.config.infer_batch_size,
            }
            self._decode_and_evaluate(loaded_infer_model,
                                      infer_sess,
                                      infer_feed_dict,
                                      label=label)
        return dev_ppl, test_ppl
Example #13
    def train(self):
        hparams = self.hparams
        train_model = self.train_model
        train_sess = self.train_sess
        model_dir = self.model_dir

        steps_per_stats = hparams.steps_per_stats
        num_train_steps = hparams.num_train_steps

        summary_name = "train_log"

        # Load train model
        with self.train_model.graph.as_default():
            loaded_train_model, global_step = model_helper.create_or_load_model(
                self.train_model.model, self.model_dir, self.train_sess, "train")

        # Summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(self.out_dir, summary_name), train_model.graph)

        # Initialize dataset iterator
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: 0})

        loss_track = []
        training_start_time = time.time()
        epoch_count = 0
        last_stats_step = global_step
        stats = train_utils.init_stats()
        best_bleu_score = 0

        while global_step < num_train_steps:
            # Run a training step
            start_time = time.time()
            try:
                train_result = loaded_train_model.train(train_sess)
            except tf.errors.OutOfRangeError:
                # Finished going through the training dataset. Go to next epoch.
                epoch_count += 1
                print("# Finished epoch %d, step %d." %
                      (epoch_count, global_step))

                # Save model params
                loaded_train_model.saver.save(
                    train_sess,
                    os.path.join(model_dir, "chatbot.ckpt"),
                    global_step=global_step)

                # Do evaluation
                self.eval(best_bleu_score)

                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
                continue

            # Write step summary and accumulate statistics
            global_step = train_utils.update_stats(
                stats, summary_writer, start_time,
                train_result.values(), best_bleu_score)

            loss_track.append(train_result['train_loss'])

            if global_step - last_stats_step >= steps_per_stats:
                last_stats_step = global_step
                is_overflow = train_utils.check_stats(stats, global_step, steps_per_stats)
                if is_overflow:
                    break

                # Reset statistics
                stats = train_utils.init_stats()

        # Training done.
        loaded_train_model.saver.save(
            train_sess,
            os.path.join(model_dir, "chatbot.ckpt"),
            global_step=global_step)

        summary_writer.close()

        print('Training done. Total time: %.4f' % (time.time() - training_start_time))
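
The `train_utils` statistics helpers used above are not part of these examples; Example #17 keeps the same bookkeeping inline (per-step time, accumulated loss, predict count). A rough sketch of two of them, with the signatures taken from this example and the perplexity arithmetic from Example #17 (everything else is an assumption):

import math

def init_stats():
    """Sketch: the counters Example #17 keeps as loose local variables."""
    return {"step_time": 0.0, "loss": 0.0, "predict_count": 0.0,
            "word_count": 0.0}

def check_stats(stats, global_step, steps_per_stats):
    """Sketch: report windowed perplexity and flag a blown-up loss."""
    avg_step_time = stats["step_time"] / steps_per_stats
    train_ppl = math.exp(min(stats["loss"] / max(stats["predict_count"], 1.0),
                             100.0))
    print("  global step %d step-time %.2fs ppl %.2f" %
          (global_step, avg_step_time, train_ppl))
    return math.isnan(train_ppl) or math.isinf(train_ppl)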
Example #14
    def train(self, target_session="", scope=None):
        out_dir = self.config.model_dir
        model_dir = out_dir

        num_train_steps = self.config.num_train_steps
        steps_per_stats = self.config.steps_per_stats
        # steps_per_external_eval = self.config.steps_per_external_eval
        steps_per_eval = 20 * steps_per_stats
        # if not steps_per_external_eval:
        #     steps_per_external_eval = 5 * steps_per_eval

        self._pre_model_creation()

        train_model = taware_helper.create_train_model(taware_model.TopicAwareSeq2SeqModel, self.config, scope)
        eval_model = taware_helper.create_eval_model(taware_model.TopicAwareSeq2SeqModel, self.config, scope)
        infer_model = taware_helper.create_infer_model(taware_model.TopicAwareSeq2SeqModel, self.config, scope)

        # Preload data for sample decoding.
        dev_file = self.config.dev_data
        eval_data = self._load_data(dev_file, include_target=True)

        summary_name = "train_log"

        # Log and output files
        log_file = os.path.join(out_dir, "log_%d" % time.time())
        log_f = tf.gfile.GFile(log_file, mode="a")
        log.print_out("# log_file=%s" % log_file, log_f)

        avg_step_time = 0.0

        # TensorFlow model
        config_proto = models.model_helper.get_config_proto(self.config.log_device)

        train_sess = tf.Session(
            target=target_session, config=config_proto, graph=train_model.graph)
        eval_sess = tf.Session(
            target=target_session, config=config_proto, graph=eval_model.graph)
        infer_sess = tf.Session(
            target=target_session, config=config_proto, graph=infer_model.graph)

        with train_model.graph.as_default():
            loaded_train_model, global_step = model_helper.create_or_load_model(
                train_model.model, model_dir, train_sess, "train")

        # Summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, summary_name), train_model.graph)

        # First evaluation
        # self.run_full_eval(
        #    model_dir, infer_model, infer_sess,
        #    eval_model, eval_sess, summary_writer, eval_data)

        last_stats_step = global_step
        last_eval_step = global_step
        # last_external_eval_step = global_step
        patience = self.config.patience

        # This is the training loop.
        stats = self.init_stats()
        speed, train_ppl = 0.0, 0.0
        start_train_time = time.time()

        log.print_out(
            "# Start step %d, epoch %d, lr %g, %s" %
            (global_step, self.config.epoch, loaded_train_model.learning_rate.eval(session=train_sess),
             time.ctime()),
            log_f)

        self.config.save()
        log.print_out("# Configs saved")

        # Initialize all of the iterators
        skip_count = self.config.batch_size * self.config.epoch_step
        log.print_out("# Init train iterator for %d steps, skipping %d elements" %
                      (self.config.num_train_steps, skip_count))

        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})

        while self.config.epoch < self.config.num_train_epochs and patience > 0:
            ### Run a step ###
            start_time = time.time()
            try:
                step_result = loaded_train_model.train(train_sess)
                self.config.epoch_step += 1
            except tf.errors.OutOfRangeError:
                # Finished going through the training dataset.  Go to next epoch.
                sw = Stopwatch()
                log.print_out(
                    "# Finished an epoch, step %d. Perform external evaluation" %
                    global_step)
                self.run_sample_decode(infer_model, infer_sess,
                                       model_dir, summary_writer, eval_data)

                log.print_out(
                    "## Done epoch %d in %d steps. step %d @ eval time: %ds" %
                    (self.config.epoch, self.config.epoch_step, global_step, sw.elapsed()))

                self.config.epoch += 1
                self.config.epoch_step = 0
                self.config.save()

                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
                continue

            # Write step summary and accumulate statistics
            global_step = self.update_stats(stats, summary_writer, start_time, step_result)

            # Once in a while, we print statistics.
            if global_step - last_stats_step >= steps_per_stats:
                last_stats_step = global_step
                train_ppl, speed, is_overflow = self.check_stats(stats, global_step, steps_per_stats, log_f)
                if is_overflow:
                    break

                # Reset statistics
                stats = self.init_stats()

            if global_step - last_eval_step >= steps_per_eval:
                last_eval_step = global_step

                log.print_out("# Save eval, global step %d" % global_step)
                log.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

                # Save checkpoint
                loaded_train_model.saver.save(
                    train_sess,
                    self.config.checkpoint_file,
                    global_step=global_step)

                # Evaluate on dev
                self.run_sample_decode(infer_model, infer_sess, model_dir, summary_writer, eval_data)
                dev_ppl, _ = self.run_internal_eval(eval_model, eval_sess, model_dir, summary_writer, use_test_set=False)

                if dev_ppl < self.config.best_dev_ppl:
                    self.config.best_dev_ppl = dev_ppl
                    patience = self.config.patience
                    log.print_out('    ** Best model thus far, ep {}|{} dev_ppl {:.3f}'.format(
                        self.config.epoch,
                        self.config.epoch_step,
                        dev_ppl))
                elif dev_ppl > self.config.degrade_threshold * self.config.best_dev_ppl:
                    patience -= 1
                    log.print_out(
                        '    worsened, ep {}|{} patience {} best_dev_ppl {:.3f}'.format(
                            self.config.epoch,
                            self.config.epoch_step,
                            patience,
                            self.config.best_dev_ppl))

                # Save config parameters
                self.config.save()

            # if global_step - last_external_eval_step >= steps_per_external_eval:
            #     last_external_eval_step = global_step
            #
            #     # Save checkpoint
            #     loaded_train_model.saver.save(
            #         train_sess,
            #         self.config.checkpoint_file,
            #         global_step=global_step)
            #     self.run_sample_decode(infer_model, infer_sess,
            #                            model_dir, summary_writer, eval_data)
                # dev_scores, test_scores, _ = self.run_external_eval(infer_model, infer_sess, model_dir, summary_writer)

        # Done training
        loaded_train_model.saver.save(
            train_sess,
            self.config.checkpoint_file,
            global_step=global_step)

        # result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = self.run_full_eval(
        #     model_dir, infer_model, infer_sess,
        #     eval_model, eval_sess,
        #     summary_writer, eval_data)
        dev_scores, test_scores, dev_ppl, test_ppl = None, None, None, None
        result_summary = ""

        log.print_out(
            "# Final, step %d lr %g "
            "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
            (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, result_summary, time.ctime()),
            log_f)
        log.print_time("# Done training!", start_train_time)

        summary_writer.close()

        # log.print_out("# Start evaluating saved best models.")
        # for metric in self.config.metrics:
        #     best_model_dir = getattr(self.config, "best_" + metric + "_dir")
        #     summary_writer = tf.summary.FileWriter(
        #         os.path.join(best_model_dir, summary_name), infer_model.graph)
        #     result_summary, best_global_step, _, _, _, _ = self.run_full_eval(
        #         best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        #         summary_writer, eval_data)
        #     log.print_out("# Best %s, step %d "
        #                   "step-time %.2f wps %.2fK, %s, %s" %
        #                   (metric, best_global_step, avg_step_time, speed,
        #                    result_summary, time.ctime()), log_f)
        #     summary_writer.close()

        return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Example #15
    def train(self, target_session="", scope=None):
        assert self.config.num_turns >= 2
        if self.config.is_pretrain_enabled():
            assert self.config.num_pretrain_turns >= 2
            assert self.config.num_turns >= self.config.num_pretrain_turns

        out_dir = self.config.model_dir

        steps_per_stats = self.config.steps_per_stats
        steps_per_eval = 20 * steps_per_stats

        _helper = self._get_model_helper()

        self._pre_model_creation()

        train_model = _helper.create_train_model(self.config, scope)
        eval_model = _helper.create_eval_model(self.config, scope)
        infer_model = _helper.create_infer_model(self.config, scope)

        self._post_model_creation(train_model, eval_model, infer_model)

        # Preload data for sample decoding.
        dev_file = self.config.dev_data
        eval_data = self._load_data(dev_file, include_target=True)

        summary_name = "train_log"

        # Log and output files
        log_file = os.path.join(out_dir, "log_%d" % time.time())
        log_f = tf.gfile.GFile(log_file, mode="a")
        log.print_out("# log_file=%s" % log_file, log_f)

        self.config.save()
        log.print_out("# Configs saved")

        avg_step_time = 0.0

        # TensorFlow model
        config_proto = model_helper.get_config_proto(self.config.log_device)

        train_sess = tf.Session(
            target=target_session, config=config_proto, graph=train_model.graph)
        eval_sess = tf.Session(
            target=target_session, config=config_proto, graph=eval_model.graph)
        infer_sess = tf.Session(
            target=target_session, config=config_proto, graph=infer_model.graph)

        # Pretraining
        num_pretrain_steps = 0
        if self.config.is_pretrain_enabled():
            num_pretrain_steps = self.config.num_pretrain_steps

            pretrain_model = _helper.create_pretrain_model(self.config, scope)

            with tf.Session(
                    target=target_session, config=config_proto, graph=pretrain_model.graph) as pretrain_sess:
                self.pretrain(pretrain_sess, pretrain_model, log_f)

        with train_model.graph.as_default():
            loaded_train_model, global_step = model_helper.create_or_load_model(
                train_model.model, self.config.model_dir, train_sess, "train")

        # Summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, summary_name), train_model.graph)

        last_stats_step = global_step
        last_eval_step = global_step
        patience = self.config.patience

        stats = self.init_stats()
        speed, train_ppl = 0.0, 0.0
        start_train_time = time.time()

        log.print_out(
            "# Start step %d, epoch %d, lr %g, %s" %
            (global_step, self.config.epoch, loaded_train_model.learning_rate.eval(session=train_sess),
             time.ctime()),
            log_f)

        # Initialize all of the iterators
        skip_count = self.config.batch_size * self.config.epoch_step
        log.print_out("# Init train iterator for %d steps, skipping %d elements" %
                      (self.config.num_train_steps, skip_count))

        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})

        while self.config.epoch < self.config.num_train_epochs and patience > 0:

            ### Run a step ###
            start_time = time.time()
            try:
                step_result = loaded_train_model.train(train_sess)
                self.config.epoch_step += 1
            except tf.errors.OutOfRangeError:
                # Finished going through the training dataset.  Go to next epoch.
                sw = Stopwatch()
                self.run_sample_decode(infer_model, infer_sess,
                                       self.config.model_dir, summary_writer, eval_data)
                # if self.config.enable_epoch_evals:
                #     dev_ppl, test_ppl = self.run_full_eval(infer_model, eval_model,
                #                                            infer_sess, eval_sess,
                #                                            out_dir,
                #                                            fs.file_name(self.config.test_data) + '_' + global_step,
                #                                            summary_writer)
                #     log.print_out(
                #         "%% done epoch %d #%d  step %d - dev_ppl: %.2f test_ppl: %.2f @ eval time: %ds" %
                #         (self.config.epoch, self.config.epoch_step, global_step, dev_ppl, test_ppl, sw.elapsed()))
                # else:
                log.print_out(
                    "## Done epoch %d in %d steps. step %d @ eval time: %ds" %
                    (self.config.epoch, self.config.epoch_step, global_step, sw.elapsed()))

                self.config.epoch += 1
                self.config.epoch_step = 0
                self.config.save()

                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
                continue

            # Write step summary and accumulate statistics
            global_step = self.update_stats(stats, summary_writer, start_time, step_result)

            # Once in a while, we print statistics.
            if global_step - last_stats_step >= steps_per_stats:
                last_stats_step = global_step
                train_ppl, speed, is_overflow = self.check_stats(stats, global_step, steps_per_stats, log_f)
                if is_overflow:
                    break

                # Reset statistics
                stats = self.init_stats()

            if global_step - last_eval_step >= steps_per_eval:
                last_eval_step = global_step

                log.print_out("# Save eval, global step %d" % global_step)
                log.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

                # Save checkpoint
                loaded_train_model.saver.save(train_sess,
                                              self.config.checkpoint_file,
                                              global_step=global_step)

                # Evaluate on dev
                self.run_sample_decode(infer_model, infer_sess, out_dir, summary_writer, eval_data)
                dev_ppl, _ = self.run_internal_eval(eval_model, eval_sess, out_dir, summary_writer,
                                                    use_test_set=False)
                if dev_ppl < self.config.best_dev_ppl:
                    self.config.best_dev_ppl = dev_ppl
                    patience = self.config.patience
                    log.print_out('    ** Best model thus far, ep {}|{} dev_ppl {:.3f}'.format(
                        self.config.epoch,
                        self.config.epoch_step,
                        dev_ppl))
                elif dev_ppl > self.config.degrade_threshold * self.config.best_dev_ppl:
                    patience -= 1
                    log.print_out(
                        '    worsened, ep {}|{} patience {} best_dev_ppl {:.3f}'.format(
                            self.config.epoch,
                            self.config.epoch_step,
                            patience,
                            self.config.best_dev_ppl))

                # Save config parameters
                self.config.save()

        # Done training
        loaded_train_model.saver.save(
            train_sess,
            self.config.checkpoint_file,
            global_step=global_step)

        if self.config.enable_final_eval:
            dev_ppl, test_ppl = self.run_full_eval(infer_model, eval_model,
                                                   infer_sess, eval_sess,
                                                   out_dir,
                                                   fs.file_name(self.config.test_data) + '_final',
                                                   summary_writer)

            log.print_out(
                "# Final, step %d ep %d/%d lr %g "
                "step-time %.2f wps %.2fK train_ppl %.2f, dev_ppl %.2f, test_ppl %.2f, %s" %
                (global_step, self.config.epoch, self.config.epoch_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, dev_ppl, test_ppl, time.ctime()),
                log_f)
        else:
            log.print_out(
                "# Final, step %d ep %d/%d lr %g "
                "step-time %.2f wps %.2fK train_ppl %.2f best_dev_ppl %.2f, %s" %
                (global_step, self.config.epoch, self.config.epoch_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, self.config.best_dev_ppl, time.ctime()),
                log_f)

        log.print_time("# Done training!", start_train_time)

        summary_writer.close()

        eval_sess.close()
        infer_sess.close()
        train_sess.close()
Example #16
    def infer(self, num_print_per_batch=0):
        model_dir = self.model_dir
        out_dir = self.out_dir
        dev_src_file = self.dev_src_file
        dev_tgt_file = self.dev_tgt_file
        infer_batch_size = self.hparams.infer_batch_size
        beam_width = self.hparams.beam_width
        infer_model = self.infer_model
        infer_sess = self.infer_sess

        infer_output_file = os.path.join(out_dir, 'infer_output')

        start_time = time.time()
        print('# Decoding to %s' % infer_output_file)

        # Load infer model
        with infer_model.graph.as_default():
            loaded_infer_model, global_step = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, "infer")

        with open(dev_src_file, encoding='utf-8') as in_src_file, \
                open(dev_tgt_file, encoding='utf-8') as in_tgt_file, \
                open(infer_output_file, mode='w', encoding='utf-8') as out_file:
            infer_src_data = in_src_file.readlines()
            infer_tgt_data = in_tgt_file.readlines()

            iterator_feed_dict = {
                infer_model.src_data_placeholder: infer_src_data,
                infer_model.batch_size_placeholder: infer_batch_size
            }
            infer_sess.run(
                infer_model.iterator.initializer,
                feed_dict=iterator_feed_dict)

            num_sentences = 0
            while True:
                try:
                    # The shape of sample_words is [batch_size, time] or
                    # [beam_width, batch_size, time] when using beam search.
                    sample_words = loaded_infer_model.decode(infer_sess)

                    if beam_width == 0:
                        sample_words = np.expand_dims(sample_words, 0)

                    batch_size = sample_words.shape[1]

                    for sent_id in range(batch_size):
                        beam_id = random.randint(0, beam_width - 1) if beam_width > 0 else 0
                        response = self._get_response(sample_words[beam_id][sent_id])
                        out_file.write(response + '\n')

                        if sent_id < num_print_per_batch:
                            global_sent_id = num_sentences + sent_id
                            print("  sentence %d" % global_sent_id)
                            print("  src: %s" % infer_src_data[global_sent_id], end='')
                            print("  ref: %s" % infer_tgt_data[global_sent_id], end='')
                            print("  bot: %s" % response)

                    num_sentences += batch_size
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d, beam width %d" %
                        (num_sentences, beam_width), start_time)
                    break
Example #17
def train(hparams, scope=None, target_session=''):
    """Train the chatbot"""
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Bind the hier-specific arguments now so both iterator constructors
        # match the signatures used by the simple architecture.
        def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                                     eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(
                dataset, vocab_table, batch_size, src_reverse, eos,
                src_max_len=src_max_len, eou=hparams.eou,
                dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset,
                               tgt_dataset,
                               vocab_table,
                               batch_size,
                               sos,
                               eos,
                               src_reverse,
                               random_seed,
                               num_buckets,
                               src_max_len=None,
                               tgt_max_len=None,
                               num_threads=4,
                               output_buffer_size=None,
                               skip_count=None):
            return end2end_iterator_utils.get_iterator(
                src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
                num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
                tgt_max_len=tgt_max_len, num_threads=num_threads,
                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unkown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # ToDo: adapt for architectures
    # Preload the data to use for sample decoding

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(train_model.model, model_dir,
                                                                            train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)
    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the hyperparameters for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch. Used to skip examples already trained on.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_step_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The model has screwed up
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning the last_eval_step variable to the current step
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation, and update the ppl variables. The data iterator is instantiated in the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores in the meanwhile. The unneeded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)