Example #1
0
  def __prepare_model(self, params):
    """Prepare G2P model for training.

    Builds a trainable Seq2SeqModel whose encoder/decoder vocabularies
    depend on the translation direction ('g2p' maps graphemes to
    phonemes; any other mode maps phonemes to graphemes) and attaches a
    checkpoint saver.

    Args:
      params: hyperparameter holder read for size, num_layers,
        max_gradient_norm, batch_size, learning_rate, lr_decay_factor,
        and optimizer.
    """

    self.params = params

    self.session = tf.Session()

    # Prepare model.
    print("Creating model with parameters:")
    print(params)

    # The two directions differ only in which vocabulary feeds the
    # encoder and which feeds the decoder; select the sizes once instead
    # of duplicating the whole constructor call in each branch.
    if self.mode == 'g2p':
      source_size, target_size = len(self.gr_vocab), len(self.ph_vocab)
    else:
      source_size, target_size = len(self.ph_vocab), len(self.gr_vocab)

    self.model = seq2seq_model.Seq2SeqModel(source_size, target_size,
                                            self._BUCKETS,
                                            self.params.size,
                                            self.params.num_layers,
                                            self.params.max_gradient_norm,
                                            self.params.batch_size,
                                            self.params.learning_rate,
                                            self.params.lr_decay_factor,
                                            forward_only=False,
                                            optimizer=self.params.optimizer)
    # Keep only the most recent checkpoint on disk.
    self.model.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
Example #2
0
  def load_decode_model(self):
    """Load G2P model and initialize or load parameters in session."""
    checkpoint_path = os.path.join(self.model_dir, 'checkpoint')
    if not os.path.exists(checkpoint_path):
      raise RuntimeError("Model not found in %s" % self.model_dir)

    # Decode one word at a time.
    self.batch_size = 1

    # Hyperparameters saved alongside the trained model.
    num_layers, size = data_utils.load_params(self.model_dir)

    # Grapheme/phoneme vocabularies, plus a reversed phoneme map used to
    # turn decoder output ids back into symbols.
    print("Loading vocabularies from %s" % self.model_dir)
    grapheme_path = os.path.join(self.model_dir, "vocab.grapheme")
    phoneme_path = os.path.join(self.model_dir, "vocab.phoneme")
    self.gr_vocab = data_utils.load_vocabulary(grapheme_path)
    self.ph_vocab = data_utils.load_vocabulary(phoneme_path)
    self.rev_ph_vocab = data_utils.load_vocabulary(phoneme_path,
                                                   reverse=True)

    self.session = tf.Session()

    # Rebuild the graph in inference mode (forward_only); the training
    # hyperparameters (gradient norm, learning rate, decay) are unused
    # here and passed as zeros.
    print("Creating %d layers of %d units." % (num_layers, size))
    self.model = seq2seq_model.Seq2SeqModel(
        len(self.gr_vocab), len(self.ph_vocab), self._BUCKETS,
        size, num_layers, 0, self.batch_size, 0, 0, forward_only=True)
    self.model.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

    # Restore the saved weights into the fresh session.
    print("Reading model parameters from %s" % self.model_dir)
    self.model.saver.restore(self.session,
                             os.path.join(self.model_dir, "model"))