def load_model(model_dir, params=None):
  """Loads a model class from a given directory
  """

  train_options = training_utils.TrainOptions.load(model_dir)

  # Load vocabulary
  source_vocab_info = vocab.get_vocab_info(train_options.source_vocab_path)
  target_vocab_info = vocab.get_vocab_info(train_options.target_vocab_path)

  # Find model class
  model_class = getattr(models, train_options.model_class)

  # Parse parameters and merge with defaults
  hparams = model_class.default_params()
  hparams.update(train_options.hparams)

  if params is not None:
    hparams.update(params)

  training_utils.print_hparams(hparams)

  # Create model instance
  model = model_class(
      source_vocab_info=source_vocab_info,
      target_vocab_info=target_vocab_info,
      params=hparams)

  return model
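
A minimal usage sketch, assuming a directory produced by a prior training run; the path and the overridden hparam key below are hypothetical, not values from this codebase:

# Hypothetical: restore a trained model and override one hyperparameter.
model = load_model(
    model_dir="/tmp/nmt_model",                  # assumed checkpoint directory
    params={"optimizer.learning_rate": 0.0005})  # assumed hparam key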
Example #2
    def __init__(self, params, mode, name):

        super(Seq2SeqModel, self).__init__(params, mode, name)

        self._source_embedding = None
        self._source_emb_scope = None
        self._target_embedding = None
        self._target_emb_scope = None
        self.source_vocab_info = None
        # Initialize the optional POS/NER attributes too, so they always
        # exist even when the corresponding params are absent or empty.
        self.source_pos_info = None
        self.source_ner_info = None

        if "vocab_source" in self.params and self.params["vocab_source"]:
            self.source_vocab_info = vocab.get_vocab_info(
                self.params["vocab_source"])
        if "pos_source" in self.params and self.params["pos_source"]:
            self.source_pos_info = vocab.get_vocab_info(
                self.params["pos_source"])
        if "ner_source" in self.params and self.params["ner_source"]:
            self.source_ner_info = vocab.get_vocab_info(
                self.params["ner_source"])

        self.target_vocab_info = None
        self.target_ner_info = None
        if "vocab_target" in self.params and self.params["vocab_target"]:
            self.target_vocab_info = vocab.get_vocab_info(
                self.params["vocab_target"])
        if "ner_target" in self.params and self.params["ner_target"]:
            self.target_ner_info = vocab.get_vocab_info(
                self.params["ner_target"])
Example #3
  def __init__(self, params, mode, name):
    super(Seq2SeqModel, self).__init__(params, mode, name)

    self.source_vocab_info = None
    if "vocab_source" in self.params and self.params["vocab_source"]:
      self.source_vocab_info = vocab.get_vocab_info(self.params["vocab_source"])

    self.target_vocab_info = None
    if "vocab_target" in self.params and self.params["vocab_target"]:
      self.target_vocab_info = vocab.get_vocab_info(self.params["vocab_target"])
Example #4
 def test_vocab_info(self):
   vocab_info = vocab.get_vocab_info(self.vocab_file.name)
   self.assertEqual(vocab_info.vocab_size, 3)
   self.assertEqual(vocab_info.path, self.vocab_file.name)
   self.assertEqual(vocab_info.special_vocab.UNK, 3)
   self.assertEqual(vocab_info.special_vocab.SEQUENCE_START, 4)
   self.assertEqual(vocab_info.special_vocab.SEQUENCE_END, 5)
   self.assertEqual(vocab_info.total_size, 6)
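
These expected values follow from a three-token vocabulary: the three special tokens are appended after the regular entries, taking ids 3 (UNK), 4 (SEQUENCE_START), and 5 (SEQUENCE_END), so total_size = 3 + 3 = 6. A hedged sketch of a setUp, inside a tf.test.TestCase subclass, that would produce such a file (the class name and token strings are arbitrary):

def setUp(self):
    super(VocabInfoTest, self).setUp()
    # Three regular tokens -> vocab_size == 3; the special tokens then
    # occupy ids 3, 4, and 5, for a total_size of 6.
    self.vocab_file = test_utils.create_temporary_vocab_file(["the", "a", "of"])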
Example #5
 def __init__(self, params, mode, name="att_seq2seq"):
     super(AttentionSeq2Seq, self).__init__(params, mode, name)
     # add candidate answer part
     self.source_candidate_vocab_info = None
     if "vocab_source_candidate" in self.params and self.params[
             "vocab_source_candidate"]:
         self.source_candidate_vocab_info = vocab.get_vocab_info(
             self.params["vocab_source_candidate"])
Example #6
 def __init__(self, params, mode, name="basic_biseq2seq"):
     super(BasicBiSeq2Seq, self).__init__(params, mode, name)
     # add candidate answer part
     self.source_candidate_vocab_info = None
     if "vocab_source_candidate" in self.params and self.params[
             "vocab_source_candidate"]:
         self.source_candidate_vocab_info = vocab.get_vocab_info(
             self.params["vocab_source_candidate"])
     self.encoder_class = locate(self.params["encoder.class"])
     self.decoder_class = locate(self.params["decoder.class"])
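
`locate` here is presumably `pydoc.locate`, which resolves a dotted-path string to the object it names and returns None on failure. A standalone sketch (the path below is an example, not taken from this snippet):

from pydoc import locate

# Turn a dotted-path string from params into a class object.
encoder_class = locate("seq2seq.encoders.BidirectionalRNNEncoder")
if encoder_class is None:
    raise ValueError("encoder.class could not be resolved")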
Example #7
  def setUp(self):
    super(EncoderDecoderTests, self).setUp()
    tf.logging.set_verbosity(tf.logging.INFO)
    self.batch_size = 2
    self.input_depth = 4
    self.sequence_length = 10

    # Create vocabulary
    self.vocab_list = [str(_) for _ in range(10)]
    self.vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
    self.vocab_size = len(self.vocab_list)
    self.vocab_file = test_utils.create_temporary_vocab_file(self.vocab_list)
    self.vocab_info = vocab.get_vocab_info(self.vocab_file.name)
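
With this 15-entry list (ten digit strings plus five extra tokens), the resulting VocabInfo should mirror the test_vocab_info example above; a hedged sanity check one could append to this setUp, assuming the same id convention:

    # 10 digit tokens + 5 extra tokens = 15 regular entries; the specials
    # are assumed to follow at ids 15-17, giving a total_size of 18.
    assert self.vocab_info.vocab_size == 15
    assert self.vocab_info.special_vocab.UNK == 15
    assert self.vocab_info.total_size == 18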
Example #8
    def __init__(self, params, mode, name):
        super(Seq2SeqModel, self).__init__(params, mode, name)

        self.source_vocab_info = None
        if "vocab_source" in self.params and self.params["vocab_source"]:
            self.source_vocab_info = vocab.get_vocab_info(
                self.params["vocab_source"])

        self.target_vocab_info = None
        if "vocab_target" in self.params and self.params["vocab_target"]:
            self.target_vocab_info = vocab.get_vocab_info(
                self.params["vocab_target"])

        self.embedding_mat_source = None
        self.embedding_mat_target = None
        # added for pretrain
        if "embedding.file" in self.params and self.params["embedding.file"]:
            self.embedding_mat_source = read_embeddings(
                self.params['embedding.file'], self.source_vocab_info.path,
                self.params["embedding.dim"], "source")
            self.embedding_mat_target = read_embeddings(
                self.params['embedding.file'], self.target_vocab_info.path,
                self.params["embedding.dim"], "target")
Example #9
  def setUp(self):
    super(EncoderDecoderTests, self).setUp()
    tf.logging.set_verbosity(tf.logging.INFO)
    self.batch_size = 2
    self.input_depth = 4
    self.sequence_length = 10

    # Create vocabulary
    self.vocab_list = [str(_) for _ in range(10)]
    self.vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
    self.vocab_size = len(self.vocab_list)
    self.vocab_file = test_utils.create_temporary_vocab_file(self.vocab_list)
    self.vocab_info = vocab.get_vocab_info(self.vocab_file.name)

    tf.contrib.framework.get_or_create_global_step()
Example #10
def test_copy_gen_model(source_path=None, target_path=None, vocab_path=None):

    tf.logging.set_verbosity(tf.logging.INFO)
    batch_size = 2
    input_depth = 4
    sequence_length = 10

    if vocab_path is None:
        # Create vocabulary
        vocab_list = [str(_) for _ in range(10)]
        vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
        vocab_size = len(vocab_list)
        vocab_file = test_utils.create_temporary_vocab_file(vocab_list)
        vocab_info = vocab.get_vocab_info(vocab_file.name)
        vocab_path = vocab_file.name
        tf.logging.info(vocab_file.name)
    else:
        vocab_info = vocab.get_vocab_info(vocab_path)
        vocab_list = get_vocab_list(vocab_path)

    extend_vocab = vocab_list + ["中国", "爱", "你"]

    tf.contrib.framework.get_or_create_global_step()
    source_len = sequence_length + 5
    target_len = sequence_length + 10
    source = " ".join(np.random.choice(extend_vocab, source_len))
    target = " ".join(np.random.choice(extend_vocab, target_len))

    is_tmp_file = False
    if source_path is None and target_path is None:
        is_tmp_file = True
        sources_file, targets_file = test_utils.create_temp_parallel_data(
            sources=[source], targets=[target])
        source_path = sources_file.name
        target_path = targets_file.name

    # Build model graph
    mode = tf.contrib.learn.ModeKeys.TRAIN
    params_ = CopyGenSeq2Seq.default_params().copy()
    params_.update({
        "vocab_source": vocab_path,
        "vocab_target": vocab_path,
    })
    model = CopyGenSeq2Seq(params=params_, mode=mode)

    tf.logging.info(source_path)
    tf.logging.info(target_path)

    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [source_path],
            "target_files": [target_path]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(pipeline=input_pipeline_,
                                              batch_size=batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        with tf.contrib.slim.queues.QueueRunners(sess):
            fetches_ = sess.run(fetches)

    if is_tmp_file:
        sources_file.close()
        targets_file.close()

    return model, fetches_
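
A hedged invocation sketch; with all arguments left as None, the function fabricates a temporary vocabulary and a one-pair parallel corpus on its own:

# Runs a single forward pass of the training graph on synthetic data.
model, fetches = test_copy_gen_model()
print(type(model).__name__, len(fetches))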
Example #11
def create_experiment(output_dir):
    """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

    config = run_config.RunConfig(
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

    # Load vocabulary info
    source_vocab_info = vocab.get_vocab_info(FLAGS.vocab_source)
    target_vocab_info = vocab.get_vocab_info(FLAGS.vocab_target)

    # Find model class
    model_class = getattr(models, FLAGS.model)

    # Parse parameters and merge with defaults
    hparams = model_class.default_params()
    if FLAGS.hparams is not None and isinstance(FLAGS.hparams, str):
        hparams = HParamsParser(hparams).parse(FLAGS.hparams)
    elif isinstance(FLAGS.hparams, dict):
        hparams.update(FLAGS.hparams)

    # Print hparams
    training_utils.print_hparams(hparams)

    # On the main worker, save training options and vocabulary
    if config.is_chief:
        # Copy vocabulary to output directory
        gfile.MakeDirs(output_dir)
        source_vocab_path = os.path.join(output_dir, "vocab_source")
        gfile.Copy(FLAGS.vocab_source, source_vocab_path, overwrite=True)
        target_vocab_path = os.path.join(output_dir, "vocab_target")
        gfile.Copy(FLAGS.vocab_target, target_vocab_path, overwrite=True)
        # Save train options
        train_options = training_utils.TrainOptions(
            hparams=hparams,
            model_class=FLAGS.model,
            source_vocab_path=source_vocab_path,
            target_vocab_path=target_vocab_path)
        train_options.dump(output_dir)

    # Create model
    model = model_class(source_vocab_info=source_vocab_info,
                        target_vocab_info=target_vocab_info,
                        params=hparams)

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Create training input function
    train_input_fn = training_utils.create_input_fn(
        data_provider_fn=functools.partial(
            data_utils.make_parallel_data_provider,
            data_sources_source=FLAGS.train_source,
            data_sources_target=FLAGS.train_target,
            shuffle=True,
            num_epochs=FLAGS.train_epochs,
            delimiter=FLAGS.delimiter),
        batch_size=FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries)

    # Create eval input function
    eval_input_fn = training_utils.create_input_fn(
        data_provider_fn=functools.partial(
            data_utils.make_parallel_data_provider,
            data_sources_source=FLAGS.dev_source,
            data_sources_target=FLAGS.dev_target,
            shuffle=False,
            num_epochs=1,
            delimiter=FLAGS.delimiter),
        batch_size=FLAGS.batch_size)

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        return model(features, labels, params, mode)

    estimator = tf.contrib.learn.estimator.Estimator(model_fn=model_fn,
                                                     model_dir=output_dir,
                                                     config=config)

    train_hooks = training_utils.create_default_training_hooks(
        estimator=estimator,
        sample_frequency=FLAGS.sample_every_n_steps,
        delimiter=FLAGS.delimiter)

    eval_metrics = {
        "log_perplexity": metrics.streaming_log_perplexity(),
        "bleu": metrics.make_bleu_metric_spec(),
    }

    experiment = tf.contrib.learn.experiment.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        min_eval_frequency=FLAGS.eval_every_n_steps,
        train_steps=FLAGS.train_steps,
        eval_steps=None,
        eval_metrics=eval_metrics,
        train_monitors=train_hooks)

    return experiment
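
In the tf.contrib.learn era a factory like this was typically handed to learn_runner; a minimal sketch, assuming FLAGS.output_dir and FLAGS.schedule are defined alongside the other flags:

from tensorflow.contrib.learn.python.learn import learn_runner

def main(_argv):
    # learn_runner calls create_experiment(output_dir) and then runs the
    # requested schedule, e.g. "train_and_evaluate".
    learn_runner.run(experiment_fn=create_experiment,
                     output_dir=FLAGS.output_dir,
                     schedule=FLAGS.schedule)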