def decode_inference(self, memory, states):
    """Decode step by step, feeding the decoder its own state, until stopped.

    Every iteration takes ``states[0]`` as the next input frame, runs it
    through ``self.prenet`` and ``self.step`` (both with ``training=False``)
    and records the step results. The loop ends as soon as *all* sigmoid
    stop-token values of a step are greater than ``self.stop_thresh``, or
    after ``self.max_decoder_steps`` iterations.

    Args:
        memory: rank-3 tensor; its first dimension is used as the batch
            size (presumably the encoder output -- confirm against caller).
        states: decoder state collection passed to and returned by
            ``self.step``; ``states[0]`` is read as the next input frame.

    Returns:
        Tuple ``(outputs, stop_tokens, attentions)``. Each is stacked over
        the decoding steps and transposed with ``[1, 0, 2]`` so the step
        axis becomes axis 1. ``outputs`` is additionally reshaped to
        ``[B, -1, self.frame_dim]`` and ``stop_tokens`` has axis 2 squeezed.
    """
    batch_size, _, _ = shape_list(memory)

    def _new_buffer():
        # Growable per-step buffer; entries stay readable after a read.
        return tf.TensorArray(dtype=tf.float32,
                              size=0,
                              clear_after_read=False,
                              dynamic_size=True)

    frame_buffer = _new_buffer()
    attention_buffer = _new_buffer()
    stop_buffer = _new_buffer()

    # Let the attention module pre-process the memory once, before looping.
    self.attention.process_values(memory)

    # Loop-carried scalars.
    initial_stop_flag = tf.constant(False, dtype=tf.bool)
    initial_step = tf.constant(0, dtype=tf.int32)

    def _decode_step(step, memory, states, frame_buffer, stop_buffer,
                     attention_buffer, stop_flag):
        prenet_out = self.prenet(states[0], training=False)
        frame, raw_stop, states, attention = self.step(
            prenet_out, states, None, training=False)
        stop_value = tf.math.sigmoid(raw_stop)
        frame_buffer = frame_buffer.write(step, frame)
        attention_buffer = attention_buffer.write(step, attention)
        stop_buffer = stop_buffer.write(step, stop_value)
        # Stop only when every stop value of this step passes the threshold.
        stop_flag = tf.reduce_all(tf.greater(stop_value, self.stop_thresh))
        return (step + 1, memory, states, frame_buffer, stop_buffer,
                attention_buffer, stop_flag)

    def _keep_decoding(step, memory, states, frame_buffer, stop_buffer,
                       attention_buffer, stop_flag):
        return tf.logical_not(stop_flag)

    _, _, _, frame_buffer, stop_buffer, attention_buffer, _ = tf.while_loop(
        _keep_decoding,
        _decode_step,
        loop_vars=(initial_step, memory, states, frame_buffer, stop_buffer,
                   attention_buffer, initial_stop_flag),
        parallel_iterations=32,
        swap_memory=True,
        maximum_iterations=self.max_decoder_steps)

    def _steps_to_axis_one(buffer):
        # stack() puts the step index on axis 0; swap it with axis 1.
        return tf.transpose(buffer.stack(), [1, 0, 2])

    outputs = tf.reshape(_steps_to_axis_one(frame_buffer),
                         [batch_size, -1, self.frame_dim])
    attentions = _steps_to_axis_one(attention_buffer)
    stop_tokens = tf.squeeze(_steps_to_axis_one(stop_buffer), axis=2)
    return outputs, stop_tokens, attentions
def decode(self, memory, states, frames, memory_seq_length=None):
    """Decode a fixed number of steps driven by the provided ``frames``.

    ``states[0]`` is prepended to ``frames`` along axis 1, the whole
    sequence is passed through ``self.prenet`` (``training=True``) in one
    call, and step ``i`` of the loop consumes ``prenet_output[:, i]``
    instead of the decoder's own previous output. The loop runs exactly
    ``shape_list(frames)[1] // self.r`` iterations (measured before the
    prepend).

    Args:
        memory: rank-3 tensor; its first dimension is used as the batch
            size (presumably the encoder output -- confirm against caller).
        states: decoder state collection passed to and returned by
            ``self.step``; ``states[0]`` supplies the initial frame.
        frames: tensor of decoder input frames with the step index on
            axis 1.
        memory_seq_length: forwarded unchanged to ``self.step`` as its
            third argument; defaults to ``None``.

    Returns:
        Tuple ``(outputs, stop_tokens, attentions)``. Each is stacked over
        the decoding steps and transposed with ``[1, 0, 2]`` so the step
        axis becomes axis 1. ``outputs`` is additionally reshaped to
        ``[B, -1, self.frame_dim]`` and ``stop_tokens`` has axis 2 squeezed.
    """
    batch_size, _, _ = shape_list(memory)
    num_steps = shape_list(frames)[1] // self.r

    # Prepend states[0] so that it is the input of the very first step.
    first_frame = tf.expand_dims(states[0], 1)
    frames = tf.concat([first_frame, frames], axis=1)

    def _new_buffer():
        # Fixed-size buffer with one slot per decoding step.
        return tf.TensorArray(dtype=tf.float32, size=num_steps)

    frame_buffer = _new_buffer()
    attention_buffer = _new_buffer()
    stop_buffer = _new_buffer()

    # Work that does not depend on the step is done once, up front.
    self.attention.process_values(memory)
    prenet_output = self.prenet(frames, training=True)
    initial_step = tf.constant(0, dtype=tf.int32)

    def _decode_step(step, memory, prenet_output, states, frame_buffer,
                     stop_buffer, attention_buffer):
        # NOTE(review): no ``training=`` is passed to self.step here, unlike
        # in decode_inference -- confirm the default is the intended one.
        frame, stop_token, states, attention = self.step(
            prenet_output[:, step], states, memory_seq_length)
        frame_buffer = frame_buffer.write(step, frame)
        attention_buffer = attention_buffer.write(step, attention)
        stop_buffer = stop_buffer.write(step, stop_token)
        return (step + 1, memory, prenet_output, states, frame_buffer,
                stop_buffer, attention_buffer)

    def _always_continue(*_unused):
        # Termination is governed solely by ``maximum_iterations``.
        return True

    _, _, _, _, frame_buffer, stop_buffer, attention_buffer = tf.while_loop(
        _always_continue,
        _decode_step,
        loop_vars=(initial_step, memory, prenet_output, states, frame_buffer,
                   stop_buffer, attention_buffer),
        parallel_iterations=32,
        swap_memory=True,
        maximum_iterations=num_steps,
    )

    def _steps_to_axis_one(buffer):
        # stack() puts the step index on axis 0; swap it with axis 1.
        return tf.transpose(buffer.stack(), [1, 0, 2])

    outputs = tf.reshape(_steps_to_axis_one(frame_buffer),
                         [batch_size, -1, self.frame_dim])
    attentions = _steps_to_axis_one(attention_buffer)
    stop_tokens = tf.squeeze(_steps_to_axis_one(stop_buffer), axis=2)
    return outputs, stop_tokens, attentions
def inference(self, characters):
    """Run the full model in inference mode on a batch of character ids.

    The input is embedded, encoded, decoded and refined by the postnet, with
    ``training=False`` passed to every sub-module.

    Args:
        characters: rank-2 tensor; its two dimensions are passed to
            ``build_decoder_initial_states`` as first and third argument
            (presumably batch size and text length -- confirm).

    Returns:
        Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
        where ``output_frames`` is ``decoder_frames`` plus the postnet
        output.
    """
    B, T = shape_list(characters)
    embedding_vectors = self.embedding(characters, training=False)
    encoder_output = self.encoder(embedding_vectors, training=False)
    # NOTE(review): 512 is hard-coded here and in ``training``; it presumably
    # has to match the encoder output size -- confirm before changing.
    decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
    decoder_frames, stop_tokens, attentions = self.decoder(
        encoder_output, decoder_states, training=False)
    postnet_frames = self.postnet(decoder_frames, training=False)
    # The postnet output is added on top of the decoder frames. The former
    # debug ``print`` of the result shape was dropped: it wrote to stdout on
    # every call and has no counterpart in ``training``.
    output_frames = decoder_frames + postnet_frames
    return decoder_frames, output_frames, attentions, stop_tokens
def training(self, characters, text_lengths, frames):
    """Run the full model in training mode.

    The input is embedded, encoded, decoded against the given ``frames`` and
    refined by the postnet, with ``training=True`` passed to every
    sub-module.

    Args:
        characters: rank-2 tensor; its two dimensions are passed to
            ``build_decoder_initial_states`` as first and third argument
            (presumably batch size and text length -- confirm).
        text_lengths: forwarded to the decoder as its fourth positional
            argument.
        frames: forwarded to the decoder as its third positional argument.

    Returns:
        Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
        where ``output_frames`` is ``decoder_frames`` plus the postnet
        output.
    """
    batch_size, num_chars = shape_list(characters)

    # Text side: embedding followed by the encoder.
    embedded = self.embedding(characters, training=True)
    memory = self.encoder(embedded, training=True)

    # NOTE(review): 512 is hard-coded here and in ``inference``; it presumably
    # has to match the encoder output size -- confirm before changing.
    initial_states = self.decoder.build_decoder_initial_states(
        batch_size, 512, num_chars)

    decoder_frames, stop_tokens, attentions = self.decoder(
        memory,
        initial_states,
        frames,
        text_lengths,
        training=True,
    )

    # The postnet output is added on top of the decoder frames.
    residual = self.postnet(decoder_frames, training=True)
    output_frames = decoder_frames + residual
    return decoder_frames, output_frames, attentions, stop_tokens