def inference(self, input_ids, input_lengths, speaker_ids, **kwargs):
    """Synthesize mel frames by free-running decoding (no teacher forcing).

    Encodes the token sequence, prepares the decoder for inference (batch
    size, alignment size, initial cell state, attention memory and, when
    ``self.use_window_mask`` is set, an attention window), runs
    ``dynamic_decode`` for at most ``self.maximum_iterations`` steps and
    adds the projected postnet output to the decoded frames.

    Args:
        input_ids: Token ids fed to the encoder; presumably shaped
            (batch, max_input_len) — confirm against the caller.
        input_lengths: Number of valid tokens per sequence. Used to build
            the encoder input mask and to mask attention over the encoder
            memory.
        speaker_ids: Speaker identifiers forwarded to the encoder.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(decoder_outputs, mel_outputs, stop_token_predictions,
        alignment_historys)`` where
          * ``decoder_outputs`` are the decoded frames reshaped to
            (batch, -1, n_mels);
          * ``mel_outputs`` are ``decoder_outputs`` plus the projected
            postnet residual;
          * ``stop_token_predictions`` are the stop-token outputs reshaped
            to (batch, -1);
          * ``alignment_historys`` is the stacked attention alignment
            history transposed with perm [1, 2, 0].
    """
    # Boolean mask marking the valid (non-padded) token positions.
    padding_mask = tf.sequence_mask(
        input_lengths,
        maxlen=tf.reduce_max(input_lengths),
        name="input_sequence_masks",
    )

    encoder_states = self.encoder(
        [input_ids, speaker_ids, padding_mask], training=False
    )
    encoder_shape = tf.shape(encoder_states)
    batch_size, alignment_size = encoder_shape[0], encoder_shape[1]

    # Prepare the decoder for inference. The attention window can only be
    # configured once the attention memory has been set up.
    decoder = self.decoder
    decoder.sampler.set_batch_size(batch_size)
    decoder.cell.set_alignment_size(alignment_size)
    initial_state = decoder.cell.get_initial_state(batch_size)
    decoder.setup_decoder_init_state(initial_state)
    attention = decoder.cell.attention_layer
    attention.setup_memory(
        memory=encoder_states,
        memory_sequence_length=input_lengths,  # masks attention over padding
    )
    if self.use_window_mask:
        attention.setup_window(win_front=self.win_front, win_back=self.win_back)

    (frames, stop_tokens, _), final_state, _ = dynamic_decode(
        decoder, maximum_iterations=self.maximum_iterations
    )

    n_mels = self.config.n_mels
    decoder_outputs = tf.reshape(frames, [batch_size, -1, n_mels])
    stop_token_predictions = tf.reshape(stop_tokens, [batch_size, -1])

    # The projected postnet output is added to the decoder frames as a residual.
    refinement = self.post_projection(self.postnet(decoder_outputs, training=False))
    mel_outputs = decoder_outputs + refinement

    alignments = tf.transpose(final_state.alignment_history.stack(), [1, 2, 0])
    return decoder_outputs, mel_outputs, stop_token_predictions, alignments
def call(
    self,
    input_ids,
    input_lengths,
    speaker_ids,
    mel_gts,
    mel_lengths,
    maximum_iterations=2000,
    use_window_mask=False,
    win_front=2,
    win_back=3,
    training=False,
    **kwargs,
):
    """Run encoder, teacher-forced decoder and postnet.

    Args:
        input_ids: Token ids fed to the encoder; presumably shaped
            (batch, max_input_len) — confirm against the caller.
        input_lengths: Number of valid tokens per sequence. Used to build
            the encoder input mask and to mask attention over the encoder
            memory.
        speaker_ids: Speaker identifiers forwarded to the encoder.
        mel_gts: Ground-truth mel frames handed to the decoder sampler as
            teacher-forcing targets.
        mel_lengths: Lengths of ``mel_gts``, handed to the sampler together
            with the targets.
        maximum_iterations: Upper bound on decoder steps passed to
            ``dynamic_decode``.
        use_window_mask: If True, configure an attention window on the
            attention layer after its memory is set up.
        win_front: Window size forwarded to ``setup_window`` as
            ``win_front``.
        win_back: Window size forwarded to ``setup_window`` as
            ``win_back``.
        training: Forwarded to the encoder and the postnet.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(decoder_outputs, mel_outputs, stop_token_prediction,
        alignment_history)`` where
          * ``decoder_outputs`` are the decoded frames reshaped to
            (batch, -1, n_mels);
          * ``mel_outputs`` are ``decoder_outputs`` plus the projected
            postnet residual;
          * ``stop_token_prediction`` is reshaped to (batch, -1);
          * ``alignment_history`` is the stacked alignment history
            transposed with perm [1, 2, 0], or an empty tuple when
            ``self.enable_tflite_convertible`` is set.

        When ``self.enable_tflite_convertible`` is set, frames whose values
        are all zero are removed from ``decoder_outputs`` and
        ``mel_outputs``. ``tf.boolean_mask`` collapses the batch axis, so
        the result has shape (1, kept_frames, n_mels); this path therefore
        effectively assumes batch size 1.
    """
    # Boolean mask marking the valid (non-padded) token positions.
    input_mask = tf.sequence_mask(
        input_lengths,
        maxlen=tf.reduce_max(input_lengths),
        name="input_sequence_masks",
    )

    # Encoder step.
    encoder_hidden_states = self.encoder(
        [input_ids, speaker_ids, input_mask], training=training
    )

    batch_size = tf.shape(encoder_hidden_states)[0]
    alignment_size = tf.shape(encoder_hidden_states)[1]

    # Decoder setup:
    # 1. mel_gts / mel_lengths as teacher-forcing targets.
    # 2. alignment_size for the attention size.
    # 3. initial state for the decoder cell.
    # 4. encoder hidden states as the attention memory.
    # 5. optional attention window (must follow the memory setup).
    self.decoder.sampler.setup_target(targets=mel_gts, mel_lengths=mel_lengths)
    self.decoder.cell.set_alignment_size(alignment_size)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(batch_size)
    )
    self.decoder.cell.attention_layer.setup_memory(
        memory=encoder_hidden_states,
        memory_sequence_length=input_lengths,  # masks attention over padding
    )
    if use_window_mask:
        self.decoder.cell.attention_layer.setup_window(
            win_front=win_front, win_back=win_back
        )

    # Decode step.
    (
        (frames_prediction, stop_token_prediction, _),
        final_decoder_state,
        _,
    ) = dynamic_decode(
        self.decoder,
        maximum_iterations=maximum_iterations,
        enable_tflite_convertible=self.enable_tflite_convertible,
    )

    decoder_outputs = tf.reshape(
        frames_prediction, [batch_size, -1, self.config.n_mels]
    )
    stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])

    residual = self.postnet(decoder_outputs, training=training)
    residual_projection = self.post_projection(residual)

    mel_outputs = decoder_outputs + residual_projection

    if self.enable_tflite_convertible:
        # Keep only frames with at least one non-zero value. The all-zero
        # frames are presumably padding from the tflite-convertible decode
        # (confirm). Compare in floating point: casting the per-frame sum
        # to int32 would truncate any frame whose summed |value| is below
        # 1.0 to 0 and wrongly drop it.
        frame_magnitude = tf.reduce_sum(tf.abs(decoder_outputs), axis=-1)
        mask = tf.math.not_equal(frame_magnitude, 0.0)
        decoder_outputs = tf.expand_dims(
            tf.boolean_mask(decoder_outputs, mask), axis=0
        )
        mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)
        # NOTE(review): stop_token_prediction is left unmasked here.
        alignment_history = ()
    else:
        alignment_history = tf.transpose(
            final_decoder_state.alignment_history.stack(), [1, 2, 0]
        )

    return decoder_outputs, mel_outputs, stop_token_prediction, alignment_history