コード例 #1
0
    def train(self, sess, progress, summary_writer):
        """Run the training loop over labeled and unlabeled minibatches.

        Every minibatch yielded by ``self._get_training_mbs`` is used for one
        update:

        * labeled minibatches (``mb.task_name != 'unlabeled'``) get a
          supervised step via ``self._model.train_labeled``;
        * unlabeled minibatches first get a teacher pass
          (``self._model.run_teacher``) and then an unsupervised step via
          ``self._model.train_unlabeled``.

        The global step is then used to schedule periodic work:

        * every ``print_every`` steps, log the mean supervised and
          unsupervised losses since the previous report plus the throughput.
          The accumulators are reset, and the mean supervised loss is written
          to ``summary_writer`` under the tag ``'loss'``;
        * every ``eval_dev_every`` steps, evaluate on dev and let
          ``progress`` save the model if it is the best so far;
        * every ``eval_train_every`` steps, evaluate on train;
        * every ``save_model_every`` steps, write a checkpoint via
          ``progress.write``.

        The loop ends when the minibatch generator is exhausted.

        Args:
            sess: TensorFlow session passed to every model call.
            progress: training-progress object. Supplies
                ``unlabeled_data_reader`` and ``history`` and provides
                ``save_if_best_dev_model`` and ``write``.
            summary_writer: receives the loss summaries and is passed
                through to ``evaluate_all_tasks``.
        """
        config = self._config

        def heading(title):
            utils.heading(title, '(' + config.model_name + ')')

        trained_on_sentences = 0
        # Throughput is cumulative since this call started (evaluation and
        # checkpointing time included), because start_time is never reset.
        start_time = time.time()
        unsupervised_loss_total, unsupervised_loss_count = 0, 0
        supervised_loss_total, supervised_loss_count = 0, 0
        for mb in self._get_training_mbs(progress.unlabeled_data_reader):
            if mb.task_name == 'unlabeled':
                self._model.run_teacher(sess, mb)
                unsupervised_loss_total += self._model.train_unlabeled(sess, mb)
                unsupervised_loss_count += 1
                # The teacher outputs were only needed for this update;
                # presumably cleared to release them early -- confirm.
                mb.teacher_predictions.clear()
            else:
                # Per-step losses are aggregated here and reported as a mean
                # every print_every steps (no per-step stdout printing).
                supervised_loss_total += self._model.train_labeled(sess, mb)
                supervised_loss_count += 1

            trained_on_sentences += mb.size
            global_step = self._model.get_global_step(sess)

            if global_step % config.print_every == 0:
                # max(1, ...) avoids dividing by zero when a reporting window
                # contained no minibatches of that kind.
                supervised_loss_reported = supervised_loss_total / max(
                    1, supervised_loss_count)
                unsupervised_loss_reported = unsupervised_loss_total / max(
                    1, unsupervised_loss_count)
                utils.log(
                    'step {:} - '
                    'supervised loss: {:.3f} - '
                    'unsupervised loss: {:.3f} - '
                    '{:.1f} sentences per second'.format(
                        global_step, supervised_loss_reported,
                        unsupervised_loss_reported,
                        trained_on_sentences / (time.time() - start_time)))
                unsupervised_loss_total, unsupervised_loss_count = 0, 0
                supervised_loss_total, supervised_loss_count = 0, 0
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag='loss',
                                         simple_value=supervised_loss_reported)
                    ]), global_step)

            if global_step % config.eval_dev_every == 0:
                heading('EVAL ON DEV')
                self.evaluate_all_tasks(sess, summary_writer, progress.history)
                progress.save_if_best_dev_model(sess, global_step)
                utils.log()

            if global_step % config.eval_train_every == 0:
                heading('EVAL ON TRAIN')
                self.evaluate_all_tasks(sess, summary_writer, progress.history,
                                        True)
                utils.log()

            if global_step % config.save_model_every == 0:
                heading('CHECKPOINTING MODEL')
                progress.write(sess, global_step)
                utils.log()
コード例 #2
0
def main():
    """Set up an experiment from FLAGS and run it in train or eval mode.

    The config is built from ``FLAGS.mode`` / ``FLAGS.model_name`` and
    written out. Next the model graph, a summary writer and two savers are
    created, and the graph is finalized. Inside a session:

    * ``'train'`` runs ``model_trainer.train``;
    * ``'eval'`` restores the latest checkpoint found in
      ``config.checkpoints_dir`` and evaluates every task.

    Raises:
        ValueError: if ``config.mode`` is neither ``'train'`` nor ``'eval'``.
            This is raised only after the graph and session have been set up.
    """
    utils.heading('SETUP')
    cfg = configure.Config(mode=FLAGS.mode, model_name=FLAGS.model_name)
    cfg.write()
    with tf.Graph().as_default() as graph:
        model_trainer = trainer.Trainer(cfg)
        summary_writer = tf.summary.FileWriter(cfg.summaries_dir)
        # Two independent single-file savers, handed to TrainingProgress as
        # its checkpoints saver and best-model saver respectively.
        ckpt_saver = tf.train.Saver(max_to_keep=1)
        best_saver = tf.train.Saver(max_to_keep=1)
        init_op = tf.global_variables_initializer()
        # Freeze the graph so no further ops can be added by accident.
        graph.finalize()
        with tf.Session() as sess:
            sess.run(init_op)
            mode = cfg.mode
            progress = training_progress.TrainingProgress(
                cfg, sess, ckpt_saver, best_saver, mode == 'train')
            utils.log()
            if mode not in ('train', 'eval'):
                raise ValueError('Mode must be "train" or "eval"')
            if mode == 'train':
                utils.heading('START TRAINING ({:})'.format(cfg.model_name))
                model_trainer.train(sess, progress, summary_writer)
                return
            # Only 'eval' remains at this point.
            utils.heading('RUN EVALUATION ({:})'.format(cfg.model_name))
            latest = tf.train.latest_checkpoint(cfg.checkpoints_dir)
            progress.best_model_saver.restore(sess, latest)
            model_trainer.evaluate_all_tasks(sess, summary_writer, None)
コード例 #3
0
ファイル: cvt.py プロジェクト: hominhtri1/CVT
# Sample inputs for the vocabulary/embedding sanity-check modes.
_EN_SAMPLE_SENTENCE = (
    "Squirrels , for example , would show up , look for the peanut , go away .")
_EN_SAMPLE_IDS = "25709 33 42 879 33 86 304 92 33 676 42 32 13406 33 273 445 34"
_VI_SAMPLE_SENTENCE = (
    "Mỗi_một khoa_học_gia đều thuộc một nhóm nghiên_cứu , và mỗi nhóm đều "
    "nghiên_cứu rất nhiều đề_tài đa_dạng .")
_VI_SAMPLE_IDS = "8976 32085 129 178 17 261 381 5 7 195 261 129 381 60 37 2474 1903 6"
_SAMPLE_WORD_ID = 50

# Modes that print one lookup result and exit without building the graph.
_LOOKUP_MODES = ('encode', 'decode', 'encode-vi', 'decode-vi', 'embed',
                 'embed-vi')
# Modes that build the graph and run inside a tf.Session.
_SESSION_MODES = ('train', 'eval-train', 'eval-dev', 'infer', 'translate',
                  'eval-translate-train', 'eval-translate-dev')
# Keyword arguments passed to evaluate_all_tasks for each evaluation mode.
_EVAL_MODE_KWARGS = {
    'eval-train': dict(train_set=True),
    'eval-dev': dict(train_set=False),
    'eval-translate-train': dict(train_set=True, is_translate=True),
    'eval-translate-dev': dict(train_set=False, is_translate=True),
}


def _print_encoded(word_vocab, text):
  """Print the vocabulary ids of the normalized whitespace tokens of text."""
  print([word_vocab[embeddings.normalize_word(w)] for w in text.split()])


def _print_decoded(word_vocab_reversed, text):
  """Print the words for the whitespace-separated integer ids in text."""
  print([word_vocab_reversed[int(w)] for w in text.split()])


def _print_embedding(word_embeddings, word_id):
  """Print word_embeddings[word_id] as space-separated values."""
  print(' '.join(str(x) for x in word_embeddings[word_id]))


def _restore_latest_checkpoint(sess, progress, config):
  """Restore the newest checkpoint in config.checkpoints_dir into sess."""
  # NOTE(review): this uses best_model_saver but reads from checkpoints_dir;
  # confirm that directory holds the checkpoint intended for eval/infer.
  progress.best_model_saver.restore(
      sess, tf.train.latest_checkpoint(config.checkpoints_dir))


def main():
  """Set up the config from FLAGS and run the requested mode.

  Lookup modes ('encode', 'decode', 'encode-vi', 'decode-vi', 'embed',
  'embed-vi') print a single vocabulary or embedding lookup on a fixed sample
  and return without building a TensorFlow graph.

  Session modes build the model graph and savers and open a session. They
  then either train ('train'), evaluate ('eval-train', 'eval-dev',
  'eval-translate-train', 'eval-translate-dev'), run inference ('infer') or
  translate ('translate'). Every mode except 'train' first restores the
  latest checkpoint.

  Raises:
    ValueError: if config.mode is not a supported mode. This is raised
      before any graph or session is created.
  """
  utils.heading('SETUP')
  config = configure.Config(mode=FLAGS.mode, model_name=FLAGS.model_name)
  config.write()
  mode = config.mode

  if mode == 'encode':
    _print_encoded(embeddings.get_word_vocab(config), _EN_SAMPLE_SENTENCE)
    return
  if mode == 'decode':
    _print_decoded(embeddings.get_word_vocab_reversed(config), _EN_SAMPLE_IDS)
    return
  if mode == 'encode-vi':
    word_vocab_vi = embeddings.get_word_vocab_vi(config)
    print(len(word_vocab_vi))
    _print_encoded(word_vocab_vi, _VI_SAMPLE_SENTENCE)
    return
  if mode == 'decode-vi':
    _print_decoded(embeddings.get_word_vocab_reversed_vi(config),
                   _VI_SAMPLE_IDS)
    return
  if mode == 'embed':
    _print_embedding(embeddings.get_word_embeddings(config), _SAMPLE_WORD_ID)
    return
  if mode == 'embed-vi':
    _print_embedding(embeddings.get_word_embeddings_vi(config),
                     _SAMPLE_WORD_ID)
    return

  # Fail fast with an accurate message instead of only after the graph,
  # savers, session and TrainingProgress have all been built.
  if mode not in _SESSION_MODES:
    raise ValueError('Mode must be one of {}; got "{}"'.format(
        ', '.join('"{}"'.format(m) for m in _LOOKUP_MODES + _SESSION_MODES),
        mode))

  with tf.Graph().as_default() as graph:
    model_trainer = trainer.Trainer(config)
    summary_writer = tf.summary.FileWriter(config.summaries_dir)
    checkpoints_saver = tf.train.Saver(max_to_keep=1)
    best_model_saver = tf.train.Saver(max_to_keep=1)
    init_op = tf.global_variables_initializer()
    graph.finalize()
    with tf.Session() as sess:
      sess.run(init_op)
      progress = training_progress.TrainingProgress(
          config, sess, checkpoints_saver, best_model_saver,
          mode == 'train')
      utils.log()
      if mode == 'train':
        utils.heading('START TRAINING ({:})'.format(config.model_name))
        model_trainer.train(sess, progress, summary_writer)
      elif mode in _EVAL_MODE_KWARGS:
        utils.heading('RUN EVALUATION ({:})'.format(config.model_name))
        _restore_latest_checkpoint(sess, progress, config)
        model_trainer.evaluate_all_tasks(
            sess, summary_writer, None, **_EVAL_MODE_KWARGS[mode])
      elif mode == 'infer':
        utils.heading('START INFER ({:})'.format(config.model_name))
        _restore_latest_checkpoint(sess, progress, config)
        model_trainer.infer(sess)
      elif mode == 'translate':
        utils.heading('START TRANSLATE ({:})'.format(config.model_name))
        _restore_latest_checkpoint(sess, progress, config)
        model_trainer.translate(sess)