def __prepare_model(self, params):
    """Prepare G2P model for training.

    Stores ``params`` on the instance, opens a new TensorFlow session and
    builds a trainable (``forward_only=False``) ``seq2seq_model.Seq2SeqModel``
    together with a ``tf.train.Saver`` over all global variables.

    Args:
      params: Training hyper-parameters. The attributes read here are
        ``size``, ``num_layers``, ``max_gradient_norm``, ``batch_size``,
        ``learning_rate``, ``lr_decay_factor`` and ``optimizer``.
    """
    self.params = params
    self.session = tf.Session()

    # Prepare model.
    print("Creating model with parameters:")
    print(params)

    # In 'g2p' mode graphemes are the source sequence and phonemes the
    # target; any other mode reverses the direction. Selecting the pair of
    # vocabularies first keeps a single constructor call instead of two
    # near-identical copies that could drift apart.
    if self.mode == 'g2p':
        source_vocab, target_vocab = self.gr_vocab, self.ph_vocab
    else:
        source_vocab, target_vocab = self.ph_vocab, self.gr_vocab

    self.model = seq2seq_model.Seq2SeqModel(
        len(source_vocab), len(target_vocab), self._BUCKETS,
        self.params.size, self.params.num_layers,
        self.params.max_gradient_norm, self.params.batch_size,
        self.params.learning_rate, self.params.lr_decay_factor,
        forward_only=False, optimizer=self.params.optimizer)
    # max_to_keep=1: the saver retains only the most recent checkpoint.
    self.model.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
def load_decode_model(self):
    """Load G2P model and initialize or load parameters in session.

    Reads the layer count/size and the grapheme/phoneme vocabularies from
    ``self.model_dir``, opens a new TensorFlow session, builds a
    forward-only ``seq2seq_model.Seq2SeqModel`` with batch size 1 and
    restores its weights from ``<model_dir>/model``.

    The network is built in the same direction as in training: graphemes →
    phonemes when ``self.mode == 'g2p'`` (or when ``mode`` is unset), and
    phonemes → graphemes otherwise. This keeps the variable shapes
    compatible with the saved checkpoint.

    Raises:
      RuntimeError: if ``<model_dir>/checkpoint`` does not exist.
    """
    if not os.path.exists(os.path.join(self.model_dir, 'checkpoint')):
        raise RuntimeError("Model not found in %s" % self.model_dir)

    self.batch_size = 1  # We decode one word at a time.

    # Load model parameters.
    num_layers, size = data_utils.load_params(self.model_dir)

    # Load vocabularies
    print("Loading vocabularies from %s" % self.model_dir)
    gr_vocab_path = os.path.join(self.model_dir, "vocab.grapheme")
    ph_vocab_path = os.path.join(self.model_dir, "vocab.phoneme")
    self.gr_vocab = data_utils.load_vocabulary(gr_vocab_path)
    self.ph_vocab = data_utils.load_vocabulary(ph_vocab_path)
    self.rev_ph_vocab = data_utils.load_vocabulary(ph_vocab_path, reverse=True)

    self.session = tf.Session()

    # Mirror the source/target choice made for training. Otherwise a model
    # trained in the non-'g2p' direction would be rebuilt with swapped
    # vocabulary sizes and its checkpoint could not be restored.
    # getattr keeps the previous always-g2p behavior if `mode` is unset.
    # NOTE(review): decoding in the non-'g2p' direction may also need a
    # reverse grapheme vocabulary; callers are not visible here — confirm.
    if getattr(self, 'mode', 'g2p') == 'g2p':
        source_vocab, target_vocab = self.gr_vocab, self.ph_vocab
    else:
        source_vocab, target_vocab = self.ph_vocab, self.gr_vocab

    # Restore model.
    print("Creating %d layers of %d units." % (num_layers, size))
    # The three zeros occupy the max_gradient_norm, learning_rate and
    # lr_decay_factor positions (same positional order as the training
    # constructor call); presumably unused when forward_only=True.
    self.model = seq2seq_model.Seq2SeqModel(
        len(source_vocab), len(target_vocab), self._BUCKETS,
        size, num_layers, 0, self.batch_size, 0, 0, forward_only=True)
    self.model.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

    # Check for saved models and restore them.
    print("Reading model parameters from %s" % self.model_dir)
    self.model.saver.restore(self.session, os.path.join(self.model_dir, "model"))