Code Example #1
File: las_wrap.py  Project: WhiteFu/TensorflowASR
    def __init__(self,
                 config,
                 training,
                 enable_tflite_convertible=True,
                 **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.training = training
        self.enable_tflite_convertible = enable_tflite_convertible
        self.attention_lstm = LayerNormLSTMCell(
            units=config.decoder_lstm_units, name="attention_lstm_cell"
        )
        self.decoder_embedding = tf.keras.layers.Embedding(config.n_classes, config.embedding_hidden_size)
        lstm_cells = []
        for i in range(config.n_lstm_decoder):
            lstm_cell = LayerNormLSTMCell(
                units=config.decoder_lstm_units, name="lstm_cell_._{}".format(i)
            )
            lstm_cells.append(lstm_cell)
        self.decoder_lstms = tf.keras.layers.StackedRNNCells(
            lstm_cells, name="decoder_lstms"
        )
        self.prenet = Prenet(config, name="prenet")
        # Define the attention layer: create location-sensitive attention.
        self.attention_layer = LocationSensitiveAttention(
            config,
            memory=None,
            mask_encoder=True,
            memory_sequence_length=None,
            is_cumulate=True,
        )
        self.classes_projection = tf.keras.layers.Dense(
            units=config.n_classes, name="classes_projection"
        )
        self.stop_projection = tf.keras.layers.Dense(
            units=1, name="stop_projection"
        )

        self.config = config
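
The constructor above reads only a handful of fields from config. As a quick sketch of which ones, the stand-in below covers every field this __init__ (and the cell's later methods) touches; SimpleNamespace and the placeholder values are assumptions for illustration, not the project's real config class.

from types import SimpleNamespace

# Hypothetical minimal config; the field names come from the code above,
# the values are placeholders.
config = SimpleNamespace(
    decoder_lstm_units=1024,    # width of the attention LSTM and decoder LSTMs
    embedding_hidden_size=256,  # decoder embedding dimension
    n_classes=1000,             # output vocabulary size
    n_lstm_decoder=1,           # number of stacked decoder LSTM cells
    attention_dim=128,          # attention projection size
    encoder_dim=512,            # context vector size (used by get_initial_state)
)

# cell = DecoderCell(config, training=True)  # requires the project's
# LayerNormLSTMCell, Prenet and LocationSensitiveAttention to be in scope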
Code Example #2
    def __init__(self,
                 vocabulary_size: int,
                 embed_dim: int,
                 embed_dropout: float = 0,
                 num_lstms: int = 1,
                 lstm_units: int = 512,
                 name="transducer_prediction",
                 **kwargs):
        super(TransducerPrediction, self).__init__(name=name, **kwargs)
        self.embed = tf.keras.layers.Embedding(
            input_dim=vocabulary_size, output_dim=embed_dim, mask_zero=False)
        self.do = tf.keras.layers.Dropout(embed_dropout)
        self.lstm_cells = []
        # All LSTM layers must use the same number of units (required for beam search).
        for i in range(num_lstms):
            lstm = LayerNormLSTMCell(units=lstm_units, dropout=embed_dropout,
                                     recurrent_dropout=embed_dropout)
            self.lstm_cells.append(lstm)
        self.decoder_lstms = tf.keras.layers.RNN(
            self.lstm_cells, return_sequences=True, return_state=True)
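
For context, the three layers above compose as embed -> dropout -> stacked recurrent layers over a token sequence. The sketch below reproduces that wiring with the stock tf.keras.layers.LSTMCell standing in for the project's LayerNormLSTMCell; the shapes and hyperparameters are assumptions for illustration.

import tensorflow as tf

vocabulary_size, embed_dim, num_lstms, lstm_units = 1000, 320, 2, 512

embed = tf.keras.layers.Embedding(vocabulary_size, embed_dim, mask_zero=False)
do = tf.keras.layers.Dropout(0.1)
cells = [tf.keras.layers.LSTMCell(lstm_units) for _ in range(num_lstms)]
decoder_lstms = tf.keras.layers.RNN(cells, return_sequences=True, return_state=True)

tokens = tf.random.uniform([4, 10], maxval=vocabulary_size, dtype=tf.int32)
x = do(embed(tokens), training=True)         # [4, 10, embed_dim]
outputs = decoder_lstms(x)                   # list: per-step outputs, then cell states
sequences, states = outputs[0], outputs[1:]  # sequences: [4, 10, lstm_units]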
Code Example #3
class DecoderCell(tf.keras.layers.AbstractRNNCell):
    """Tacotron-2 custom decoder cell."""
    def __init__(self,
                 config,
                 training,
                 enable_tflite_convertible=True,
                 **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.training = training
        self.enable_tflite_convertible = enable_tflite_convertible
        self.attention_lstm = LayerNormLSTMCell(
            units=config.decoder_lstm_units, name="attention_lstm_cell")
        self.decoder_embedding = tf.keras.layers.Embedding(
            config.n_classes, config.embedding_hidden_size)
        lstm_cells = []
        for i in range(config.n_lstm_decoder):
            lstm_cell = tf.keras.layers.LSTMCell(
                units=config.decoder_lstm_units,
                name="lstm_cell_._{}".format(i))
            lstm_cells.append(lstm_cell)
        self.decoder_lstms = tf.keras.layers.StackedRNNCells(
            lstm_cells, name="decoder_lstms")
        self.prenet = Prenet(config, name="prenet")
        # Define the attention layer: create location-sensitive attention.
        self.attention_layer = LocationSensitiveAttention(
            config,
            memory=None,
            mask_encoder=True,
            memory_sequence_length=None,
            is_cumulate=True,
        )
        self.classes_projection = tf.keras.layers.Dense(
            units=config.n_classes, name="classes_projection")
        self.stop_projection = tf.keras.layers.Dense(units=1,
                                                     name="stop_projection")

        self.config = config

    def set_alignment_size(self, alignment_size):
        self.alignment_size = alignment_size

    @property
    def output_size(self):
        """Return output (mel) size."""
        return self.classes_projection.units

    @property
    def state_size(self):
        """Return hidden state size."""
        return DecoderCellState(
            attention_lstm_state=self.attention_lstm.state_size,
            decoder_lstms_state=self.decoder_lstms.state_size,
            time=tf.TensorShape([]),
            context=self.config.attention_dim,  # keyword must match the 'context' field used in get_initial_state and call
            state=self.alignment_size,
            alignment_history=(),
            max_alignments=tf.TensorShape([1]),
        )

    def get_initial_state(self, batch_size):
        """Get initial states."""
        initial_attention_lstm_cell_states = self.attention_lstm.get_initial_state(
            None, batch_size, dtype=tf.float32)
        initial_decoder_lstms_cell_states = self.decoder_lstms.get_initial_state(
            None, batch_size, dtype=tf.float32)
        initial_context = tf.zeros(shape=[batch_size, self.config.encoder_dim],
                                   dtype=tf.float32)
        initial_state = self.attention_layer.get_initial_state(
            batch_size, size=self.alignment_size)
        if self.enable_tflite_convertible:
            initial_alignment_history = ()
        else:
            initial_alignment_history = tf.TensorArray(dtype=tf.float32,
                                                       size=0,
                                                       dynamic_size=True)
        return DecoderCellState(
            attention_lstm_state=initial_attention_lstm_cell_states,
            decoder_lstms_state=initial_decoder_lstms_cell_states,
            time=tf.zeros([], dtype=tf.int32),
            context=initial_context,
            state=initial_state,
            alignment_history=initial_alignment_history,
            max_alignments=tf.zeros([batch_size], dtype=tf.int32),
        )

    def call(self, inputs, states):
        """Call logic."""
        # inputs: [batch_size, 1] token ids -> embed and squeeze to [batch_size, embed_dim]
        decoder_input = self.decoder_embedding(inputs)[:, 0, :]

        # 1. apply prenet for decoder_input.
        prenet_out = self.prenet(decoder_input,
                                 training=self.training)  # [batch_size, dim]

        # 2. concat prenet_out and prev context vector
        # then use it as input of attention lstm layer.
        attention_lstm_input = tf.concat([prenet_out, states.context], axis=-1)
        attention_lstm_output, next_attention_lstm_state = self.attention_lstm(
            attention_lstm_input, states.attention_lstm_state)

        # 3. compute context, alignment and cumulative alignment.
        prev_state = states.state
        if not self.enable_tflite_convertible:
            prev_alignment_history = states.alignment_history
        prev_max_alignments = states.max_alignments
        context, alignments, state = self.attention_layer(
            [attention_lstm_output, prev_state, prev_max_alignments],
            training=self.training,
        )

        # 4. run decoder lstm(s)
        decoder_lstms_input = tf.concat([attention_lstm_output, context],
                                        axis=-1)
        decoder_lstms_output, next_decoder_lstms_state = self.decoder_lstms(
            decoder_lstms_input, states.decoder_lstms_state)

        # 5. compute frame feature and stop token.
        projection_inputs = tf.concat([decoder_lstms_output, context], axis=-1)
        decoder_outputs = self.classes_projection(projection_inputs)

        stop_inputs = tf.concat([decoder_lstms_output, decoder_outputs],
                                axis=-1)
        stop_tokens = self.stop_projection(stop_inputs)

        # 6. save alignment history to visualize.
        if self.enable_tflite_convertible:
            alignment_history = ()
        else:
            alignment_history = prev_alignment_history.write(
                states.time, alignments)

        # 7. return new states.
        new_states = DecoderCellState(
            attention_lstm_state=next_attention_lstm_state,
            decoder_lstms_state=next_decoder_lstms_state,
            time=states.time + 1,
            context=context,
            state=state,
            alignment_history=alignment_history,
            max_alignments=tf.argmax(alignments, -1, output_type=tf.int32),
        )

        return (decoder_outputs, stop_tokens), new_states
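
To make the step semantics concrete, here is a hedged sketch of a greedy decode loop driving the cell one token at a time. The names encoder_max_time, start_token_id, and max_steps are assumptions, config is the minimal stand-in sketched after Code Example #1, and the attention layer's memory is presumed to have been set from the encoder outputs beforehand.

# Hypothetical driver loop; names and sizes are illustrative only.
batch_size, max_steps, start_token_id = 4, 50, 0

cell = DecoderCell(config, training=False)
cell.set_alignment_size(encoder_max_time)    # alignment length = encoder time steps
states = cell.get_initial_state(batch_size)

inputs = tf.fill([batch_size, 1], start_token_id)   # [batch, 1] token ids
logits_per_step, stop_per_step = [], []
for _ in range(max_steps):
    (logits, stop), states = cell(inputs, states)   # one decoder step
    logits_per_step.append(logits)
    stop_per_step.append(stop)
    inputs = tf.argmax(logits, axis=-1, output_type=tf.int32)[:, tf.newaxis]

logits = tf.stack(logits_per_step, axis=1)   # [batch, max_steps, n_classes]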