def inference(inputs, input_lengths):
    """Greedy inference pass: encode, CTC-decode phones, then attention-decode tokens.

    NOTE(review): the body uses ``self`` but the signature does not declare it —
    presumably this is defined in a scope (e.g. a closure inside a method or a
    ``tf.function`` wrapper) where ``self`` is captured; confirm against the
    enclosing context before moving this function.

    Args:
        inputs: audio features, or raw waveform when ``self.mel_layer`` is set
            (assumed — TODO confirm expected input shape against callers).
        input_lengths: per-utterance frame lengths with a trailing size-1 axis
            (squeezed away below).

    Returns:
        ``[final_decoded, phone_decode]`` — argmax token ids from the attention
        decoder, and the CTC greedy-decoded phone sequence.
    """
    # Encoder Step.
    input_lengths = tf.squeeze(input_lengths, -1)
    if self.mel_layer is not None:
        inputs = self.mel_layer(inputs)
    # The encoder yields three intermediate feature maps; all three are
    # concatenated on the channel axis to form the attention memory.
    encoder_hidden_states1, encoder_hidden_states2, encoder_hidden_states3 = self.encoder(
        inputs, training=False)
    encoder_hidden_states = tf.concat([
        encoder_hidden_states1, encoder_hidden_states2, encoder_hidden_states3
    ], -1)
    batch_size = tf.shape(encoder_hidden_states)[0]
    alignment_size = tf.shape(encoder_hidden_states)[1]
    # CTC branch: greedy-decode phones from the deepest encoder output.
    ctc3_output = self.fc3(encoder_hidden_states3)
    phone_decode = tf.keras.backend.ctc_decode(
        tf.nn.softmax(ctc3_output, -1), input_length=input_lengths)[0][0]
    # Decoder setup (order matters). Include:
    # 1. batch_size for inference.
    # 2. alignment_size for attention size.
    # 3. initial state for decoder cell.
    # 4. memory (encoder hidden state) for attention mechanism.
    # 5. window front/back to solve long sentence synthesize problems.
    #    (must be called after setup_memory.)
    self.decoder.sampler.set_batch_size(batch_size)
    self.decoder.cell.set_alignment_size(alignment_size)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(batch_size))
    self.decoder.cell.attention_layer.setup_memory(
        memory=encoder_hidden_states,
        memory_sequence_length=input_lengths,  # use for mask attention.
    )
    if self.use_window_mask:
        self.decoder.cell.attention_layer.setup_window(
            win_front=self.win_front, win_back=self.win_back)
    # Run the attention decoder. The stop-token stream, final decoder state,
    # and alignment history are not used by this inference path, so they are
    # discarded instead of being reshaped/transposed (the original computed
    # them and threw them away).
    (
        (classes_prediction, _, _),
        _,
        _,
    ) = dynamic_decode(self.decoder,
                       maximum_iterations=self.maximum_iterations)
    # 768 here looks like a BERT hidden size — TODO confirm against the
    # decoder_project layer's expected input width.
    bert_output = tf.reshape(classes_prediction, [batch_size, -1, 768])
    decoder_output = self.decoder_project(bert_output)
    decoder_output = self.token_project(decoder_output)
    final_decoded = self.fc_final(decoder_output)
    final_decoded = tf.argmax(final_decoded, -1)
    return [final_decoded, phone_decode]
def inference(inputs, input_lengths):
    """Greedy inference pass for the single-output-encoder variant.

    NOTE(review): the body uses ``self`` but the signature does not declare it —
    presumably ``self`` is captured from an enclosing scope; confirm before
    relocating this function.

    Args:
        inputs: audio features, or raw waveform when ``self.mel_layer`` is set.
        input_lengths: per-utterance frame lengths with a trailing size-1 axis
            (squeezed away below).

    Returns:
        ``[decoder_output]`` — argmax class ids from the attention decoder.
    """
    # Encoder Step.
    input_lengths = tf.squeeze(input_lengths, -1)
    if self.wav_info:
        # Keep the raw waveform around; the encoder consumes it alongside
        # the mel features when wav_info is enabled.
        wav = inputs
    if self.mel_layer is not None:
        inputs = self.mel_layer(inputs)
    # Only the encoder's last output is used as the attention memory.
    if self.wav_info:
        encoder_hidden_states = self.encoder([inputs, wav], training=False)[-1]
    else:
        encoder_hidden_states = self.encoder(inputs, training=False)[-1]
    batch_size = tf.shape(encoder_hidden_states)[0]
    alignment_size = tf.shape(encoder_hidden_states)[1]
    # Decoder setup (order matters). Include:
    # 1. batch_size for inference.
    # 2. alignment_size for attention size.
    # 3. initial state for decoder cell.
    # 4. memory (encoder hidden state) for attention mechanism.
    # 5. window front/back to solve long sentence synthesize problems.
    #    (must be called after setup_memory.)
    self.decoder.sampler.set_batch_size(batch_size)
    self.decoder.cell.set_alignment_size(alignment_size)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(batch_size))
    self.decoder.cell.attention_layer.setup_memory(
        memory=encoder_hidden_states,
        memory_sequence_length=input_lengths,  # use for mask attention.
    )
    if self.use_window_mask:
        self.decoder.cell.attention_layer.setup_window(
            win_front=self.win_front, win_back=self.win_back)
    # Run the attention decoder. The stop-token stream, final decoder state,
    # and alignment history are unused by this inference path, so they are
    # discarded instead of being reshaped/transposed (the original computed
    # them and threw them away).
    (
        (classes_prediction, _, _),
        _,
        _,
    ) = dynamic_decode(self.decoder,
                       maximum_iterations=self.maximum_iterations)
    decoder_output = tf.reshape(classes_prediction,
                                [batch_size, -1, self.config.n_classes])
    decoder_output = tf.argmax(decoder_output, -1)
    return [decoder_output]
def call(
    self,
    inputs,
    targets=None,
    targets_lengths=None,
    use_window_mask=False,
    win_front=2,
    win_back=3,
    training=False,
):
    """Encode the input features and run the attention decoder over them.

    ``inputs`` is a pair ``(features, feature_lengths)``. When ``targets``
    is provided the sampler is put into teacher-forcing mode.

    Returns:
        ``(decoder_output, stop_token_prediction, alignment_history)``.
        In TFLite-convertible mode ``alignment_history`` is an empty tuple.
    """
    feats, feat_lengths = inputs
    if self.mel_layer is not None:
        feats = self.mel_layer(feats)
    memory = self.encoder(feats, training=training)
    batch = tf.shape(memory)[0]
    mem_len = tf.shape(memory)[1]
    # Teacher forcing: hand ground-truth targets to the sampler.
    if targets is not None:
        self.decoder.sampler.setup_target(targets=targets,
                                          targets_lengths=targets_lengths)
    # Decoder bookkeeping, in order: sampler batch size, attention width,
    # initial cell state, then the attention memory itself.
    self.decoder.sampler.set_batch_size(batch)
    self.decoder.cell.set_alignment_size(mem_len)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(batch))
    self.decoder.cell.attention_layer.setup_memory(
        memory=memory,
        memory_sequence_length=feat_lengths,  # use for mask attention.
    )
    # Window mask is configured only after the memory is in place.
    if use_window_mask:
        self.decoder.cell.attention_layer.setup_window(win_front=win_front,
                                                       win_back=win_back)
    # Run the decode loop.
    (class_logits, stop_logits, _), dec_state, _ = dynamic_decode(
        self.decoder,
        maximum_iterations=self.maximum_iterations,
        enable_tflite_convertible=self.enable_tflite_convertible)
    outputs = tf.reshape(class_logits, [batch, -1, self.config.n_classes])
    stop_logits = tf.reshape(stop_logits, [batch, -1])
    if self.enable_tflite_convertible:
        # Keep only frames whose feature vector is not all-zero, and skip
        # the alignment history (not representable in this mode).
        keep = tf.math.not_equal(
            tf.cast(tf.reduce_sum(tf.abs(outputs), axis=-1),
                    dtype=tf.int32), 0)
        outputs = tf.expand_dims(tf.boolean_mask(outputs, keep), axis=0)
        alignments = ()
    else:
        alignments = tf.transpose(dec_state.alignment_history.stack(),
                                  [1, 2, 0])
    return outputs, stop_logits, alignments
def call(
    self,
    inputs,
    input_lengths,
    targets=None,
    targets_lengths=None,
    use_window_mask=False,
    win_front=2,
    win_back=3,
    training=False,
):
    """Training/forward pass: encode with three taps, CTC heads on each tap,
    then attention-decode into a BERT-sized space and project to classes.

    Args:
        inputs: encoder input features (shape assumed ``[batch, time, dim]``
            — TODO confirm against the encoder).
        input_lengths: per-utterance lengths used to mask attention.
        targets: optional ground-truth targets for teacher forcing.
        targets_lengths: lengths matching ``targets``.
        use_window_mask: enable a windowed attention mask (long utterances).
        win_front: window frames in front of the attention focus.
        win_back: window frames behind the attention focus.
        training: forwarded to the encoder and the token projection.

    Returns:
        Tuple of ``(ctc1_output, ctc2_output, ctc3_output, final_decoded,
        bert_output, stop_token_prediction, alignment_history)``.
    """
    # Encoder Step.
    # input_lengths=tf.squeeze(input_lengths,-1)
    # Three intermediate encoder outputs; each feeds its own CTC head.
    encoder_hidden_states1, encoder_hidden_states2, encoder_hidden_states3 = self.encoder(
        inputs, training=training)
    ctc1_output = self.fc1(encoder_hidden_states1)
    ctc2_output = self.fc2(encoder_hidden_states2)
    ctc3_output = self.fc3(encoder_hidden_states3)
    # Concatenate all taps on the channel axis to form the attention memory.
    encoder_hidden_states = tf.concat([
        encoder_hidden_states1, encoder_hidden_states2, encoder_hidden_states3
    ], -1)
    batch_size = tf.shape(encoder_hidden_states)[0]
    alignment_size = tf.shape(encoder_hidden_states)[1]
    # Setup some initial placeholders for decoder step. Include:
    # 1. mel_outputs, mel_lengths for teacher forcing mode.
    # 2. alignment_size for attention size.
    # 3. initial state for decoder cell.
    # 4. memory (encoder hidden state) for attention mechanism.
    if targets is not None:
        self.decoder.sampler.setup_target(targets=targets,
                                          targets_lengths=targets_lengths)
    self.decoder.sampler.set_batch_size(batch_size)
    self.decoder.cell.set_alignment_size(alignment_size)
    self.decoder.setup_decoder_init_state(
        self.decoder.cell.get_initial_state(batch_size))
    self.decoder.cell.attention_layer.setup_memory(
        memory=encoder_hidden_states,
        memory_sequence_length=input_lengths,  # use for mask attention.
    )
    # Window mask must be set up after the memory (see sibling methods).
    if use_window_mask:
        self.decoder.cell.attention_layer.setup_window(win_front=win_front,
                                                       win_back=win_back)
    # run decode step.
    (
        (classes_prediction, stop_token_prediction, _),
        final_decoder_state,
        _,
    ) = dynamic_decode(
        self.decoder,
        maximum_iterations=self.maximum_iterations,
        enable_tflite_convertible=self.enable_tflite_convertible)
    # 768 looks like a BERT hidden size — TODO confirm it matches
    # decoder_project's expected input width.
    bert_output = tf.reshape(classes_prediction, [batch_size, -1, 768])
    stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])
    decoder_output = self.decoder_project(bert_output)
    decoder_output = self.token_project(decoder_output, training=training)
    final_decoded = self.fc_final(decoder_output)
    # Alignment history stacked as [batch, memory_time, decode_steps].
    alignment_history = tf.transpose(
        final_decoder_state.alignment_history.stack(), [1, 2, 0])
    return ctc1_output, ctc2_output, ctc3_output, final_decoded, bert_output, stop_token_prediction, alignment_history