Пример #1
0
    def decode_inference(self, memory, states):
        """Autoregressively decode frames until the stop token fires.

        Each iteration feeds ``states[0]`` through the prenet (``training=False``)
        and calls ``self.step``. The loop ends once the sigmoid of every stop
        logit in the batch exceeds ``self.stop_thresh``, or after
        ``self.max_decoder_steps`` iterations, whichever comes first.

        Args:
            memory: encoder output; ``shape_list`` must yield exactly three
                dims, the first being the batch size.
            states: decoder state structure consumed and returned by
                ``self.step``; ``states[0]`` is used as the next input frame.

        Returns:
            Tuple ``(outputs, stop_tokens, attentions)`` with the step axis
            moved to position 1. ``outputs`` is reshaped to
            ``[B, -1, self.frame_dim]``, ``stop_tokens`` holds sigmoid
            probabilities with its trailing axis squeezed away.
        """
        batch_size, _, _ = shape_list(memory)

        def _growable_array():
            # Length is unknown up front, so the array must grow per step;
            # entries stay readable for the final stack().
            return tf.TensorArray(dtype=tf.float32,
                                  size=0,
                                  clear_after_read=False,
                                  dynamic_size=True)

        frame_buf = _growable_array()
        align_buf = _growable_array()
        stop_buf = _growable_array()

        # Let the attention module precompute whatever it needs from memory.
        self.attention.process_values(memory)

        def _decode_step(step, memory, states, frame_buf, stop_buf, align_buf,
                         done):
            prenet_in = self.prenet(states[0], training=False)
            frame, stop_logit, states, alignment = self.step(prenet_in,
                                                             states,
                                                             None,
                                                             training=False)
            stop_prob = tf.math.sigmoid(stop_logit)
            frame_buf = frame_buf.write(step, frame)
            align_buf = align_buf.write(step, alignment)
            stop_buf = stop_buf.write(step, stop_prob)
            # Only stop when every sequence in the batch has crossed the threshold.
            done = tf.reduce_all(tf.greater(stop_prob, self.stop_thresh))
            return step + 1, memory, states, frame_buf, stop_buf, align_buf, done

        def _keep_going(*loop_vars):
            # The last loop variable is the "all sequences finished" flag.
            return tf.logical_not(loop_vars[-1])

        _, memory, states, frame_buf, stop_buf, align_buf, _ = tf.while_loop(
            _keep_going,
            _decode_step,
            loop_vars=(tf.constant(0, dtype=tf.int32), memory, states,
                       frame_buf, stop_buf, align_buf,
                       tf.constant(False, dtype=tf.bool)),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=self.max_decoder_steps)

        # Stacked arrays are step-major; move the step axis behind the batch axis.
        frames = tf.transpose(frame_buf.stack(), [1, 0, 2])
        alignments = tf.transpose(align_buf.stack(), [1, 0, 2])
        stop_probs = tf.squeeze(tf.transpose(stop_buf.stack(), [1, 0, 2]),
                                axis=2)
        frames = tf.reshape(frames, [batch_size, -1, self.frame_dim])
        return frames, stop_probs, alignments
Пример #2
0
    def decode(self, memory, states, frames, memory_seq_length=None):
        """Decode using the supplied frames as per-step decoder inputs.

        ``states[0]`` is prepended to ``frames`` along axis 1. The whole
        sequence is passed through the prenet once with ``training=True``.
        ``self.step`` is then called ``shape_list(frames)[1] // self.r``
        times; step ``i`` consumes prenet frame ``i``.

        Args:
            memory: encoder output; ``shape_list`` must yield exactly three
                dims, the first being the batch size.
            states: initial decoder state structure for ``self.step``;
                ``states[0]`` becomes the first input frame.
            frames: decoder input frames, concatenated after ``states[0]``
                on axis 1.
            memory_seq_length: optional value forwarded unchanged to
                ``self.step``.

        Returns:
            Tuple ``(outputs, stop_tokens, attentions)`` with the step axis
            moved to position 1. ``outputs`` is reshaped to
            ``[B, -1, self.frame_dim]``. Stop tokens are returned exactly as
            ``self.step`` produced them (no sigmoid is applied here), with
            their trailing axis squeezed away.
        """
        batch_size, _, _ = shape_list(memory)
        steps = shape_list(frames)[1] // self.r

        # The initial state frame acts as the input for the very first step.
        go_frame = tf.expand_dims(states[0], 1)
        frames = tf.concat([go_frame, frames], axis=1)

        frame_buf = tf.TensorArray(dtype=tf.float32, size=steps)
        align_buf = tf.TensorArray(dtype=tf.float32, size=steps)
        stop_buf = tf.TensorArray(dtype=tf.float32, size=steps)

        self.attention.process_values(memory)
        # Prenet runs once over the full input sequence rather than per step.
        prenet_seq = self.prenet(frames, training=True)

        def _teacher_step(step, memory, prenet_seq, states, frame_buf,
                          stop_buf, align_buf):
            frame, stop_logit, states, alignment = self.step(
                prenet_seq[:, step], states, memory_seq_length)
            frame_buf = frame_buf.write(step, frame)
            align_buf = align_buf.write(step, alignment)
            stop_buf = stop_buf.write(step, stop_logit)
            return (step + 1, memory, prenet_seq, states, frame_buf, stop_buf,
                    align_buf)

        # The condition is always true; the iteration count alone bounds the loop.
        _, _, _, _, frame_buf, stop_buf, align_buf = tf.while_loop(
            lambda *_: True,
            _teacher_step,
            loop_vars=(tf.constant(0, dtype=tf.int32), memory, prenet_seq,
                       states, frame_buf, stop_buf, align_buf),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=steps,
        )

        # Stacked arrays are step-major; move the step axis behind the batch axis.
        outputs = tf.transpose(frame_buf.stack(), [1, 0, 2])
        alignments = tf.transpose(align_buf.stack(), [1, 0, 2])
        stop_logits = tf.squeeze(tf.transpose(stop_buf.stack(), [1, 0, 2]),
                                 axis=2)
        outputs = tf.reshape(outputs, [batch_size, -1, self.frame_dim])
        return outputs, stop_logits, alignments
Пример #3
0
 def inference(self, characters):
     """Run the full model in inference mode on a batch of character ids.

     Embedding, encoder, decoder and postnet are all called with
     ``training=False``; no target frames are given to the decoder.

     Args:
         characters: batch of character ids; ``shape_list`` must yield
             exactly two dims ``(B, T)``.

     Returns:
         Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``,
         where ``output_frames`` is ``decoder_frames`` plus the postnet output.
     """
     B, T = shape_list(characters)
     embedding_vectors = self.embedding(characters, training=False)
     encoder_output = self.encoder(embedding_vectors, training=False)
     # NOTE(review): 512 is hard-coded; presumably it must match a layer
     # width elsewhere in the model - confirm and move to configuration.
     decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
     decoder_frames, stop_tokens, attentions = self.decoder(encoder_output,
                                                            decoder_states,
                                                            training=False)
     postnet_frames = self.postnet(decoder_frames, training=False)
     # The postnet output is added element-wise to the decoder frames.
     output_frames = decoder_frames + postnet_frames
     return decoder_frames, output_frames, attentions, stop_tokens
Пример #4
0
 def training(self, characters, text_lengths, frames):
     """Run the full model in training mode.

     Embedding, encoder, decoder and postnet are all called with
     ``training=True``. ``frames`` and ``text_lengths`` are forwarded to the
     decoder positionally.

     Args:
         characters: batch of character ids; ``shape_list`` must yield
             exactly two dims ``(B, T)``.
         text_lengths: forwarded to the decoder after ``frames``.
         frames: forwarded to the decoder after the initial states.

     Returns:
         Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``,
         where ``output_frames`` is ``decoder_frames`` plus the postnet output.
     """
     batch_size, text_len = shape_list(characters)
     encoded = self.encoder(self.embedding(characters, training=True),
                            training=True)
     # NOTE(review): 512 is hard-coded; presumably a model width - confirm.
     init_states = self.decoder.build_decoder_initial_states(batch_size, 512,
                                                             text_len)
     decoder_frames, stop_tokens, attentions = self.decoder(
         encoded, init_states, frames, text_lengths, training=True)
     refined_frames = decoder_frames + self.postnet(decoder_frames,
                                                    training=True)
     return decoder_frames, refined_frames, attentions, stop_tokens