def inference(self, input_ids, input_lengths, speaker_ids):
    """Synthesize mel spectrograms from input token ids (inference mode).

    Args:
        input_ids: batch of input token id sequences.
        input_lengths: true (unpadded) length of each sequence.
        speaker_ids: per-utterance speaker ids for the encoder.

    Returns:
        Tuple of (decoder_output, mel_outputs, stop_token_prediction,
        alignment_history).
    """
    # Mask out padded positions of the inputs.
    mask = tf.sequence_mask(
        input_lengths,
        maxlen=tf.reduce_max(input_lengths),
        name="input_sequence_masks",
    )

    # Encode the inputs; dropout and friends disabled (training=False).
    memory = tf.identity(self.encoder([input_ids, speaker_ids, mask], training=False)) if False else self.encoder(
        [input_ids, speaker_ids, mask], training=False
    )
    n_batch = tf.shape(memory)[0]
    n_alignment = tf.shape(memory)[1]

    # Prepare the decoder before running the decode loop:
    #   1. batch size for the inference sampler,
    #   2. alignment (attention) size,
    #   3. zero initial decoder state,
    #   4. encoder memory for the attention mechanism,
    #   5. optional attention window to stabilise long sentences
    #      (must be set AFTER the memory).
    self.decoder.sampler.set_batch_size(n_batch)
    self.decoder.cell.set_alignment_size(n_alignment)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(n_batch)
    )
    self.decoder.cell.attention_layer.setup_memory(
        memory=memory,
        memory_sequence_length=input_lengths,  # masks attention over padding.
    )
    if self.use_window_mask:
        self.decoder.cell.attention_layer.setup_window(
            win_front=self.win_front, win_back=self.win_back
        )

    # Autoregressive decoding, capped at self.maximum_iterations steps.
    (
        (frames, stop_tokens, _),
        last_state,
        _,
    ) = dynamic_decode(self.decoder, maximum_iterations=self.maximum_iterations)

    before_mels = tf.reshape(frames, [n_batch, -1, self.config.n_mels])
    stop_tokens = tf.reshape(stop_tokens, [n_batch, -1])

    # Postnet predicts a residual correction on top of the raw frames.
    residual = self.postnet(before_mels, training=False)
    after_mels = before_mels + self.post_projection(residual)

    # [max_time, batch, alignment] -> [batch, alignment, max_time]
    alignments = tf.transpose(last_state.alignment_history.stack(), [1, 2, 0])

    return before_mels, after_mels, stop_tokens, alignments
def call(self, input_ids, input_lengths, speaker_ids, mel_outputs, mel_lengths, maximum_iterations=tf.constant(2000, tf.int32), use_window_mask=False, win_front=2, win_back=3, training=False):
    """Teacher-forced forward pass of the model.

    Args:
        input_ids: batch of input token id sequences.
        input_lengths: true (unpadded) length of each input sequence.
        speaker_ids: per-utterance speaker ids for the encoder.
        mel_outputs: ground-truth mel frames used for teacher forcing.
        mel_lengths: true length of each mel sequence.
        maximum_iterations: hard cap on decoder steps.
        use_window_mask: enable windowed attention masking.
        win_front: attention window size in front of the current step.
        win_back: attention window size behind the current step.
        training: forwarded to encoder/postnet sub-layers.

    Returns:
        Tuple of (decoder_output, mel_outputs, stop_token_prediction,
        alignment_history).
    """
    # Mask out padded positions of the inputs.
    mask = tf.sequence_mask(
        input_lengths,
        maxlen=tf.reduce_max(input_lengths),
        name='input_sequence_masks',
    )

    # Encode the inputs.
    memory = self.encoder([input_ids, speaker_ids, mask], training=training)
    n_batch = tf.shape(memory)[0]
    n_alignment = tf.shape(memory)[1]

    # Prepare the decoder before running the decode loop:
    #   1. ground-truth targets for teacher forcing,
    #   2. alignment (attention) size,
    #   3. zero initial decoder state,
    #   4. encoder memory for the attention mechanism,
    #   5. optional attention window masking.
    self.decoder.sampler.setup_target(targets=mel_outputs, mel_lengths=mel_lengths)
    self.decoder.cell.set_alignment_size(n_alignment)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(n_batch)
    )
    self.decoder.cell.attention_layer.setup_memory(
        memory=memory,
        memory_sequence_length=input_lengths,  # masks attention over padding.
    )
    if use_window_mask:
        self.decoder.cell.attention_layer.setup_window(
            win_front=win_front, win_back=win_back
        )

    # Run the decode loop (teacher-forced via the sampler targets).
    (
        (frames, stop_tokens, _),
        last_state,
        _,
    ) = dynamic_decode(self.decoder, maximum_iterations=maximum_iterations)

    before_mels = tf.reshape(frames, [n_batch, -1, self.config.n_mels])
    stop_tokens = tf.reshape(stop_tokens, [n_batch, -1])

    # Postnet predicts a residual correction on top of the raw frames.
    residual = self.postnet(before_mels, training=training)
    after_mels = before_mels + self.post_projection(residual)

    # [max_time, batch, alignment] -> [batch, alignment, max_time]
    alignments = tf.transpose(last_state.alignment_history.stack(), [1, 2, 0])

    return before_mels, after_mels, stop_tokens, alignments
def call(self, inputs, training=None, mask=None):
    """Build and run the attention decoder.

    In inference mode (`self.is_infer`), `inputs` is
    (enc_outputs, enc_state, enc_seq_len) and decoding is driven by a
    greedy helper or beam search; returns the predicted token ids of the
    best beam/path. In training mode, `inputs` is
    (dec_inputs, dec_seq_len, enc_outputs, enc_state, enc_seq_len) and
    decoding is teacher-forced; returns the raw rnn_output logits.

    NOTE(review): uses TF1-style `seq2seq` helpers and `tf.layers.Dense`;
    presumably `tf.contrib.seq2seq` / tensorflow-addons — confirm.
    """
    # Embedding lookup used by the greedy / beam-search decoders.
    dec_emb_fn = lambda ids: self.embed(ids)
    if self.is_infer:
        # Inference: only encoder results are provided; decoding starts
        # from self.dec_start_id and stops at self.dec_end_id.
        enc_outputs, enc_state, enc_seq_len = inputs
        batch_size = tf.shape(enc_outputs)[0]
        helper = seq2seq.GreedyEmbeddingHelper(
            embedding=dec_emb_fn,
            start_tokens=tf.fill([batch_size], self.dec_start_id),
            end_token=self.dec_end_id)
    else:
        # Training: ground-truth decoder inputs drive teacher forcing.
        dec_inputs, dec_seq_len, enc_outputs, enc_state, \
            enc_seq_len = inputs
        batch_size = tf.shape(enc_outputs)[0]
        dec_inputs = self.embed(dec_inputs)
        helper = seq2seq.TrainingHelper(
            inputs=dec_inputs, sequence_length=dec_seq_len)
    if self.is_infer and self.beam_size > 1:
        # Beam search: encoder outputs, lengths and state must all be
        # tiled beam_size times so each beam sees its own copy.
        tiled_enc_outputs = seq2seq.tile_batch(
            enc_outputs, multiplier=self.beam_size)
        tiled_seq_len = seq2seq.tile_batch(enc_seq_len,
                                           multiplier=self.beam_size)
        attn_mech = self._build_attention(
            enc_outputs=tiled_enc_outputs, enc_seq_len=tiled_seq_len)
        dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
        tiled_enc_last_state = seq2seq.tile_batch(
            enc_state, multiplier=self.beam_size)
        # Zero state sized for batch_size * beam_size hypotheses.
        tiled_dec_init_state = dec_cell.zero_state(
            batch_size=batch_size * self.beam_size, dtype=tf.float32)
        if self.initial_decode_state:
            # Optionally seed the decoder with the encoder's final state.
            tiled_dec_init_state = tiled_dec_init_state.clone(
                cell_state=tiled_enc_last_state)
        dec = seq2seq.BeamSearchDecoder(
            cell=dec_cell,
            embedding=dec_emb_fn,
            start_tokens=tf.tile([self.dec_start_id], [batch_size]),
            end_token=self.dec_end_id,
            initial_state=tiled_dec_init_state,
            beam_width=self.beam_size,
            output_layer=tf.layers.Dense(self.vocab_size),
            length_penalty_weight=self.length_penalty)
    else:
        # Greedy inference or training: plain attention decoder,
        # no tiling required.
        attn_mech = self._build_attention(
            enc_outputs=enc_outputs, enc_seq_len=enc_seq_len)
        dec_cell = seq2seq.AttentionWrapper(
            cell=self.cell, attention_mechanism=attn_mech)
        dec_init_state = dec_cell.zero_state(
            batch_size=batch_size, dtype=tf.float32)
        if self.initial_decode_state:
            # Optionally seed the decoder with the encoder's final state.
            dec_init_state = dec_init_state.clone(cell_state=enc_state)
        dec = seq2seq.BasicDecoder(
            cell=dec_cell,
            helper=helper,
            initial_state=dec_init_state,
            output_layer=tf.layers.Dense(self.vocab_size))
    if self.is_infer:
        # Inference: bounded by max_dec_len; return ids of the best beam
        # (index 0 — beams are sorted best-first).
        dec_outputs, _, _ = \
            seq2seq.dynamic_decode(decoder=dec,
                                   maximum_iterations=self.max_dec_len,
                                   swap_memory=self.swap_memory,
                                   output_time_major=self.time_major)
        return dec_outputs.predicted_ids[:, :, 0]
    else:
        # Training: decode exactly as far as the longest target sequence
        # and return logits for the loss.
        dec_outputs, _, _ = \
            seq2seq.dynamic_decode(decoder=dec,
                                   maximum_iterations=tf.reduce_max(dec_seq_len),
                                   swap_memory=self.swap_memory,
                                   output_time_major=self.time_major)
        return dec_outputs.rnn_output