Example #1
def reference_encoder(inputs,
                      filters,
                      kernel_size,
                      strides,
                      encoder_cell,
                      is_training,
                      scope='ref_encoder'):
    with tf.variable_scope(scope):
        ref_outputs = tf.expand_dims(inputs, axis=-1)
        # CNN stack
        for i, channel in enumerate(filters):
            ref_outputs = conv2d(ref_outputs, channel, kernel_size, strides,
                                 tf.nn.relu, is_training, 'conv2d_%d' % i)

        shapes = shape_list(ref_outputs)
        ref_outputs = tf.reshape(ref_outputs,
                                 shapes[:-2] + [shapes[2] * shapes[3]])
        # RNN
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell,
                                                           ref_outputs,
                                                           dtype=tf.float32)

        reference_state = tf.layers.dense(encoder_outputs[:, -1, :],
                                          128,
                                          activation=tf.nn.tanh)  # [N, 128]
        return reference_state
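Note: these snippets rely on a shape_list helper (and a conv2d wrapper) defined elsewhere in their repositories. A common TF 1.x implementation of shape_list mixes static and dynamic dimensions; the sketch below is an assumption added for readability, not code taken from the examples.

import tensorflow as tf

def shape_list(tensor):
    # Return the shape as a Python list, preferring static dimensions and
    # falling back to dynamic tf.shape() entries where the static value is unknown.
    static = tensor.get_shape().as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim
            for i, dim in enumerate(static)]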
Example #2
    def __init__(self, inputs, is_training, hparams=None, scope='emt_disc'):

        self._hparams = hparams
        filters = [32, 32, 64, 64, 128, 128]
        kernel_size = (3, 3)
        strides = (2, 2)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):

            encoder_cell = GRUCell(128)
            ref_outputs = tf.expand_dims(inputs, axis=-1)

            # CNN stack
            for i, channel in enumerate(filters):
                ref_outputs = conv2d(ref_outputs, channel, kernel_size,
                                     strides, tf.nn.relu, is_training,
                                     'conv2d_%d' % i)

            shapes = shape_list(ref_outputs)

            ref_outputs = tf.reshape(ref_outputs,
                                     shapes[:-2] + [shapes[2] * shapes[3]])
            # RNN
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                encoder_cell, ref_outputs, dtype=tf.float32)

            emb = tf.layers.dense(encoder_outputs[:, -1, :],
                                  128,
                                  activation=tf.nn.tanh)  # [N, 128]
            self.logit = tf.layers.dense(emb, 4)
            self.emb = tf.expand_dims(emb, axis=1)  # [N,1,128]
Example #3
def transpose_then_concat_last_two_dimension(tensor):
    tensor = tf.transpose(
        tensor,
        [0, 2, 1, 3])  # [batch_size, max_seq_len, num_heads, dim]
    t_shape = shape_list(tensor)
    num_heads, dim = t_shape[-2:]
    return tf.reshape(tensor, t_shape[:-2] + [num_heads * dim])
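Example #3 undoes the head split performed by split_last_dimension_then_transpose in Example #6 below. A self-contained NumPy sketch (shapes chosen arbitrarily) illustrates that the two reshape/transpose steps round-trip:

import numpy as np

batch, seq_len, num_heads, dim = 2, 5, 4, 32
x = np.random.randn(batch, seq_len, dim).astype(np.float32)

# split_last_dimension_then_transpose: [B, T, D] -> [B, H, T, D/H]
split = x.reshape(batch, seq_len, num_heads, dim // num_heads).transpose(0, 2, 1, 3)
assert split.shape == (batch, num_heads, seq_len, dim // num_heads)

# transpose_then_concat_last_two_dimension: [B, H, T, D/H] -> [B, T, D]
combined = split.transpose(0, 2, 1, 3).reshape(batch, seq_len, dim)
assert np.array_equal(combined, x)  # the two helpers are inverses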
Example #4
def reference_encoder(inputs,
                      filters,
                      kernel_size,
                      strides,
                      is_training,
                      scope="reference_encoder"):
    """
    Use 6 x 2-D convolution layers and A single GRU
    """
    with tf.variable_scope(scope):
        # inputs: N x T x M x 1, output N x T x M x C
        outputs = tf.expand_dims(inputs, axis=-1)  # for 2-D conv
        for i, channels in enumerate(filters):
            outputs = conv2d(outputs, kernel_size, channels, strides,
                             tf.nn.relu, is_training, 're_conv2d_%d' % i)

        # reshape to 3 dimensions, preserving the time resolution

        shapes = shape_list(outputs)
        outputs = tf.reshape(outputs,
                             [shapes[0], shapes[1], shapes[2] * shapes[3]])

        # apply a single rnn layer
        outputs, states = tf.nn.dynamic_rnn(GRUCell(128),
                                            outputs,
                                            dtype=tf.float32)
        # the last state serves as the reference encoder embedding
        return tf.nn.tanh(outputs), states
Example #5
def _combine_heads(self, x):
    '''Combine all heads
    Returns:
        a Tensor with shape [batch, length_x, shape_x[-1] * shape_x[-3]]
    '''
    x = tf.transpose(x, [0, 2, 1, 3])
    x_shape = shape_list(x)
    return tf.reshape(x, x_shape[:-2] + [self.num_heads * x_shape[-1]])
Example #6
def split_last_dimension_then_transpose(tensor, num_heads):
    t_shape = shape_list(tensor)
    dim = t_shape[-1]
    assert dim % num_heads == 0
    tensor = tf.reshape(tensor,
                        t_shape[:-1] + [num_heads, dim // num_heads])
    return tf.transpose(
        tensor,
        [0, 2, 1, 3])  # [batch_size, num_heads, max_seq_len, dim]
Example #7
def _split_last_dimension(self, x, num_heads):
    '''Reshape x to num_heads
    Returns:
        a Tensor with shape [batch, length_x, num_heads, dim_x/num_heads]
    '''
    x_shape = shape_list(x)
    dim = x_shape[-1]
    assert dim % num_heads == 0
    return tf.reshape(x, x_shape[:-1] + [num_heads, dim // num_heads])
Example #8
def _split_heads(self, q, k, v):
    '''Split the channels into multiple heads

    Returns:
        Tensors with shape [batch, num_heads, length_x, dim_x/num_heads]
    '''
    qs = tf.transpose(self._split_last_dimension(q, self.num_heads),
                      [0, 2, 1, 3])
    ks = tf.transpose(self._split_last_dimension(k, self.num_heads),
                      [0, 2, 1, 3])
    v_shape = shape_list(v)
    vs = tf.tile(tf.expand_dims(v, axis=1), [1, self.num_heads, 1, 1])
    return qs, ks, vs
Example #9
def discriminator(inputs_content,
                  inputs_mel,
                  is_training,
                  scope='discriminator'):

    filters = [32, 32, 64, 64, 128, 128]
    kernel_size = (3, 3)
    strides = (2, 2)
    encoder_cell = GRUCell(128)

    # Average the content encoding over time, keeping the time axis so it
    # can be tiled to the mel length below.
    inputs_content = tf.reduce_mean(inputs_content, axis=1, keepdims=True)
    inputs_content = tf.tile(inputs_content, [1, shape_list(inputs_mel)[1], 1])

    inputs = tf.concat([inputs_mel, inputs_content], axis=-1)

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        ref_outputs = tf.expand_dims(inputs, axis=-1)

        # CNN stack
        for i, channel in enumerate(filters):
            ref_outputs = conv2d(ref_outputs, channel, kernel_size, strides,
                                 tf.nn.relu, is_training, 'conv2d_%d' % i)

        shapes = shape_list(ref_outputs)

        ref_outputs = tf.reshape(ref_outputs,
                                 shapes[:-2] + [shapes[2] * shapes[3]])
        # RNN
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell,
                                                           ref_outputs,
                                                           dtype=tf.float32)

        reference_state = tf.layers.dense(encoder_outputs[:, -1, :],
                                          128,
                                          activation=tf.nn.tanh)  # [N, 128]
        logit = tf.layers.dense(reference_state, 3)
        return logit
Example #10
  def initialize(self, inputs, input_lengths, inputs_jp=None, mel_targets=None, linear_targets=None):
    '''Initializes the model for inference.

    Sets "mel_outputs", "linear_outputs", and "alignments" fields.

    Args:
      inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of
        steps in the input time series, and values are character IDs
      input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths
        of each sequence in inputs.
      mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number
        of steps in the output time series, M is num_mels, and values are entries in the mel
        spectrogram. Only needed for training.
      linear_targets: float32 Tensor with shape [N, T_out, F] where N is batch_size, T_out is number
        of steps in the output time series, F is num_freq, and values are entries in the linear
        spectrogram. Only needed for training.
    '''
    with tf.variable_scope('inference') as scope:
      is_training = linear_targets is not None
      is_teacher_force_generating = mel_targets is not None
      batch_size = tf.shape(inputs)[0]
      hp = self._hparams

      # Embeddings
      # embedding_table = tf.get_variable(
      #   'text_embedding', [len(symbols), hp.embed_depth], dtype=tf.float32,
      #   initializer=tf.truncated_normal_initializer(stddev=0.5))
      # embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)           # [N, T_in, 256]
      
      if hp.use_gst:
        #Global style tokens (GST)
        gst_tokens = tf.get_variable(
          'style_tokens', [hp.num_gst, hp.style_embed_depth // hp.num_heads], dtype=tf.float32,
          initializer=tf.truncated_normal_initializer(stddev=0.5))
        self.gst_tokens = gst_tokens
 
      # Encoder
      # prenet_outputs = prenet(embedded_inputs, is_training)
      prenet_outputs = prenet(inputs, is_training)
      # [N, T_in, 128]
      encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths, is_training)  # [N, T_in, 256]
      


      if inputs_jp is not None:
        # Reference encoder
        refnet_outputs = reference_encoder(
          inputs_jp,
          filters=hp.reference_filters, 
          kernel_size=(3,3),
          strides=(2,2),
          encoder_cell=GRUCell(hp.reference_depth),
          is_training=is_training)                                                 # [N, 128]
        self.refnet_outputs = refnet_outputs                                       

        if hp.use_gst:
          # Style attention
          style_attention = MultiheadAttention(
            tf.expand_dims(refnet_outputs, axis=1),                                   # [N, 1, 128]
            tf.tanh(tf.tile(tf.expand_dims(gst_tokens, axis=0), [batch_size,1,1])),            # [N, hp.num_gst, 256/hp.num_heads]   
            num_heads=hp.num_heads,
            num_units=hp.style_att_dim,
            attention_type=hp.style_att_type)

          style_embeddings = style_attention.multi_head_attention()                   # [N, 1, 256]
        else:
          style_embeddings = tf.expand_dims(refnet_outputs, axis=1)                   # [N, 1, 128]
      else:
        print("Use random weight for GST.")
        random_weights = tf.random_uniform([hp.num_heads, hp.num_gst], maxval=1.0, dtype=tf.float32)
        random_weights = tf.nn.softmax(random_weights, name="random_weights")
        style_embeddings = tf.matmul(random_weights, tf.nn.tanh(gst_tokens))
        style_embeddings = tf.reshape(style_embeddings, [1, 1] + [hp.num_heads * gst_tokens.get_shape().as_list()[1]])

      # Add style embedding to every text encoder state
      style_embeddings = tf.tile(style_embeddings, [1, shape_list(encoder_outputs)[1], 1]) # [N, T_in, 128]
      encoder_outputs = tf.concat([encoder_outputs, style_embeddings], axis=-1)

      # Attention
      attention_cell = AttentionWrapper(
        GRUCell(hp.attention_depth),
        BahdanauAttention(hp.attention_depth, encoder_outputs, memory_sequence_length=input_lengths),
        alignment_history=True,
        output_attention=False)                                                  # [N, T_in, 256]

      # Concatenate attention context vector and RNN cell output.
      concat_cell = ConcatOutputAndAttentionWrapper(attention_cell)              

      # Decoder (layers specified bottom to top):
      decoder_cell = MultiRNNCell([
          OutputProjectionWrapper(concat_cell, hp.rnn_depth),
          ResidualWrapper(ZoneoutWrapper(LSTMCell(hp.rnn_depth), 0.1)),
          ResidualWrapper(ZoneoutWrapper(LSTMCell(hp.rnn_depth), 0.1))
        ], state_is_tuple=True)                                                  # [N, T_in, 256]

      # Project onto r mel spectrograms (predict r outputs at each RNN step):
      output_cell = OutputProjectionWrapper(decoder_cell, hp.num_mels * hp.outputs_per_step)
      decoder_init_state = output_cell.zero_state(batch_size=batch_size, dtype=tf.float32)

      if is_training or is_teacher_force_generating:
        helper = TacoTrainingHelper(inputs, mel_targets, hp)
      else:
        helper = TacoTestHelper(batch_size, hp)

      (decoder_outputs, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
        BasicDecoder(output_cell, helper, decoder_init_state),
        maximum_iterations=hp.max_iters)                                        # [N, T_out/r, M*r]

      # Reshape outputs to be one output per entry
      mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hp.num_mels]) # [N, T_out, M]

      # Add post-processing CBHG:
      post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training)           # [N, T_out, 256]
      linear_outputs = tf.layers.dense(post_outputs, hp.num_freq)               # [N, T_out, F]

      # Grab alignments from the final decoder state:
      alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0])

      self.inputs = inputs
      self.input_lengths = input_lengths
      self.mel_outputs = mel_outputs
      self.encoder_outputs = encoder_outputs
      self.style_embeddings = style_embeddings
      self.linear_outputs = linear_outputs
      self.alignments = alignments
      self.mel_targets = mel_targets
      self.linear_targets = linear_targets
      self.inputs_jp = inputs_jp
      log('Initialized Tacotron model. Dimensions: ')
      log('  style embedding:         %d' % style_embeddings.shape[-1])
      log('  prenet out:              %d' % prenet_outputs.shape[-1])
      log('  encoder out:             %d' % encoder_outputs.shape[-1])
      log('  attention out:           %d' % attention_cell.output_size)
      log('  concat attn & out:       %d' % concat_cell.output_size)
      log('  decoder cell out:        %d' % decoder_cell.output_size)
      log('  decoder out (%d frames):  %d' % (hp.outputs_per_step, decoder_outputs.shape[-1]))
      log('  decoder out (1 frame):   %d' % mel_outputs.shape[-1])
      log('  postnet out:             %d' % post_outputs.shape[-1])
      log('  linear out:              %d' % linear_outputs.shape[-1])
Example #11
    def initialize(self,
                   inputs,
                   input_lengths,
                   mel_targets=None,
                   linear_targets=None,
                   reference_mel=None):
        '''Initializes the model for inference.

    Sets "mel_outputs", "linear_outputs", and "alignments" fields.

    Args:
      inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of
        steps in the input time series, and values are character IDs
      input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths
        of each sequence in inputs.
      mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number
        of steps in the output time series, M is num_mels, and values are entries in the mel
        spectrogram. Only needed for training.
      linear_targets: float32 Tensor with shape [N, T_out, F] where N is batch_size, T_out is number
        of steps in the output time series, F is num_freq, and values are entries in the linear
        spectrogram. Only needed for training.
    '''
        with tf.variable_scope('inference') as scope:
            is_training = linear_targets is not None
            batch_size = tf.shape(inputs)[0]
            hp = self._hparams

            # Embeddings
            embedding_table = tf.get_variable(
                'text_embedding', [len(symbols), hp.embed_depth],
                dtype=tf.float32,
                initializer=tf.truncated_normal_initializer(stddev=0.5))
            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     inputs)  # [N, T_in, 256]

            #Global style tokens (GST)
            gst_tokens = tf.get_variable(
                'style_tokens',
                [hp.num_gst, hp.style_embed_depth // hp.num_heads],
                dtype=tf.float32,
                initializer=tf.truncated_normal_initializer(stddev=0.5))
            self.gst_tokens = gst_tokens

            # Encoder

            encoder_outputs = encoder(embedded_inputs, input_lengths,
                                      is_training, 512, 5,
                                      256)  # [N, T_in, 256]

            if is_training:
                reference_mel = mel_targets

            if reference_mel is not None:
                # Reference encoder
                refnet_outputs = reference_encoder(
                    reference_mel,
                    filters=hp.ref_filters,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    encoder_cell=GRUCell(hp.ref_depth),
                    is_training=is_training)  # [N, 128]
                self.refnet_outputs = refnet_outputs

                # Style attention
                style_attention = MultiheadAttention(
                    tf.expand_dims(refnet_outputs, axis=1),  # [N, 1, 128]
                    tf.tanh(
                        tf.tile(tf.expand_dims(gst_tokens, axis=0),
                                [batch_size, 1, 1
                                 ])),  # [N, hp.num_gst, 256/hp.num_heads]   
                    num_heads=hp.num_heads,
                    num_units=hp.style_att_dim,
                    attention_type=hp.style_att_type)

                embedded_tokens = style_attention.multi_head_attention(
                )  # [N, 1, 256]

            else:
                random_weights = tf.constant(
                    hp.num_heads * [[0] * (hp.gst_index - 1) + [1] + [0] *
                                    (hp.num_gst - hp.gst_index)],
                    dtype=tf.float32)
                random_weights = tf.nn.softmax(random_weights,
                                               name="random_weights")
                # gst_tokens = tf.tile(gst_tokens, [1, hp.num_heads])
                embedded_tokens = tf.matmul(random_weights,
                                            tf.nn.tanh(gst_tokens))
                embedded_tokens = hp.gst_scale * embedded_tokens
                embedded_tokens = tf.reshape(
                    embedded_tokens, [1, 1] +
                    [hp.num_heads * gst_tokens.get_shape().as_list()[1]])

            # Add style embedding to every text encoder state
            style_embeddings = tf.tile(
                embedded_tokens,
                [1, shape_list(encoder_outputs)[1], 1])  # [N, T_in, 128]
            encoder_outputs = tf.concat([encoder_outputs, style_embeddings],
                                        axis=-1)

            # Attention
            attention_mechanism = LocationSensitiveAttention(
                128,
                encoder_outputs,
                hparams=hp,
                is_training=is_training,
                mask_encoder=True,
                memory_sequence_length=input_lengths,
                smoothing=False,
                cumulate_weights=True)
            decoder_lstm = [
                ZoneoutLSTMCell(1024,
                                is_training,
                                zoneout_factor_cell=0.1,
                                zoneout_factor_output=0.1,
                                name='decoder_LSTM_{}'.format(i + 1))
                for i in range(2)
            ]

            decoder_lstm = MultiRNNCell(decoder_lstm, state_is_tuple=True)
            decoder_init_state = decoder_lstm.zero_state(
                batch_size=batch_size, dtype=tf.float32)  # not present in TensorFlow 1

            attention_cell = AttentionWrapper(
                decoder_lstm,
                attention_mechanism,
                initial_cell_state=decoder_init_state,
                alignment_history=True,
                output_attention=False)

            # attention_state_size = 256
            # Decoder input -> prenet -> decoder_lstm -> concat[output, attention]
            # dec_outputs = DecoderPrenetWrapper(attention_cell, is_training, hp.prenet_depths)
            dec_outputs_cell = OutputProjectionWrapper(
                attention_cell, (hp.num_mels) * hp.outputs_per_step)

            if is_training:
                helper = TacoTrainingHelper(inputs, mel_targets, hp)
            else:
                helper = TacoTestHelper(batch_size, hp)

            decoder_init_state = dec_outputs_cell.zero_state(
                batch_size=batch_size, dtype=tf.float32)
            (decoder_outputs,
             _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                 BasicDecoder(dec_outputs_cell, helper, decoder_init_state),
                 maximum_iterations=hp.max_iters)  # [N, T_out/r, M*r]

            # Reshape outputs to be one output per entry
            decoder_mel_outputs = tf.reshape(
                decoder_outputs[:, :, :hp.num_mels * hp.outputs_per_step],
                [batch_size, -1, hp.num_mels])  # [N, T_out, M]

            x = decoder_mel_outputs
            for i in range(5):
                activation = tf.nn.tanh if i != (4) else None
                x = tf.layers.conv1d(x,
                                     filters=512,
                                     kernel_size=5,
                                     padding='same',
                                     activation=activation,
                                     name='Postnet_{}'.format(i))
                x = tf.layers.batch_normalization(x, training=is_training)
                x = tf.layers.dropout(x,
                                      rate=0.5,
                                      training=is_training,
                                      name='Postnet_dropout_{}'.format(i))

            residual = tf.layers.dense(x,
                                       hp.num_mels,
                                       name='residual_projection')
            mel_outputs = decoder_mel_outputs + residual

            # Add post-processing CBHG:
            # mel_outputs: (N,T,num_mels)
            post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training)
            linear_outputs = tf.layers.dense(
                post_outputs,
                hp.num_freq)  # [N, T_out, F(1025)]             # [N, T_out, F]

            # Grab alignments from the final decoder state:
            alignments = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0])

            self.inputs = inputs
            self.input_lengths = input_lengths
            self.decoder_mel_outputs = decoder_mel_outputs
            self.mel_outputs = mel_outputs
            self.encoder_outputs = encoder_outputs
            self.style_embeddings = style_embeddings
            self.linear_outputs = linear_outputs
            self.alignments = alignments
            self.mel_targets = mel_targets
            self.linear_targets = linear_targets
            self.reference_mel = reference_mel
            self.all_vars = tf.trainable_variables()
            log('Initialized Tacotron model. Dimensions: ')
            log('  text embedding:          %d' % embedded_inputs.shape[-1])
            log('  style embedding:         %d' % style_embeddings.shape[-1])
            # log('  prenet out:              %d' % prenet_outputs.shape[-1])
            log('  encoder out:             %d' % encoder_outputs.shape[-1])
            log('  attention out:           %d' % attention_cell.output_size)
            # log('  concat attn & out:       %d' % concat_cell.output_size)
            log('  decoder cell out:        %d' % dec_outputs_cell.output_size)
            log('  decoder out (%d frames):  %d' %
                (hp.outputs_per_step, decoder_outputs.shape[-1]))
            log('  decoder out (1 frame):   %d' % mel_outputs.shape[-1])
            log('  postnet out:             %d' % post_outputs.shape[-1])
            log('  linear out:              %d' % linear_outputs.shape[-1])
Example #12
    def initialize(self,
                   inputs,
                   input_lengths,
                   mel_targets_pos=None,
                   linear_targets_pos=None,
                   mel_targets_neg=None,
                   linear_targets_neg=None,
                   labels_pos=None,
                   labels_neg=None,
                   reference_mel_pos=None,
                   reference_mel_neg=None):

        is_training = linear_targets_pos is not None
        is_teacher_force_generating = mel_targets_pos is not None
        batch_size = tf.shape(inputs)[0]
        hp = self._hparams

        ## Text Encoding scope
        with tf.variable_scope('text_encoder', reuse=tf.AUTO_REUSE) as scope:
            # Initialize Text Embeddings
            embedding_table = tf.get_variable(
                'text_embedding', [len(symbols), 256],
                dtype=tf.float32,
                initializer=tf.truncated_normal_initializer(stddev=0.5))
            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     inputs)  # [N, T_in, 256]

            # Text Encoder
            prenet_outputs = prenet(embedded_inputs,
                                    is_training)  # [N, T_in, 128]
            encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths,
                                           is_training)  # [N, T_in, 256]

            content_inputs = encoder_outputs

        ## Reference Encoding Scope
        with tf.variable_scope('audio_encoder', reuse=tf.AUTO_REUSE) as scope:

            if hp.use_gst:
                #Global style tokens (GST)
                gst_tokens = tf.get_variable(
                    'style_tokens', [hp.num_gst, 256 // hp.num_heads],
                    dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=0.5))
                self.gst_tokens = gst_tokens

            if is_training:

                reference_mel_pos = mel_targets_pos
                reference_mel_neg = mel_targets_neg

            if reference_mel_pos is not None:
                # Reference encoder
                refnet_outputs_pos = reference_encoder(
                    reference_mel_pos,
                    filters=[32, 32, 64, 64, 128, 128],
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    encoder_cell=GRUCell(128),
                    is_training=is_training)  # [n, 128]
                self.refnet_outputs_pos = refnet_outputs_pos

                refnet_outputs_neg = reference_encoder(
                    reference_mel_neg,
                    filters=[32, 32, 64, 64, 128, 128],
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    encoder_cell=GRUCell(128),
                    is_training=is_training)  # [n, 128]
                self.refnet_outputs_neg = refnet_outputs_neg
                # Extract style features
                ref_style = style_encoder(reference_mel_neg,
                                          filters=[32, 32, 64, 64],
                                          kernel_size=(3, 3),
                                          strides=(2, 2),
                                          is_training=False)
                self.ref_style = ref_style

                if hp.use_gst:
                    # Multi-head attention
                    style_attention_pos = MultiheadAttention(
                        tf.tanh(tf.expand_dims(refnet_outputs_pos,
                                               axis=1)),  # [N, 1, 128]
                        tf.tile(tf.expand_dims(gst_tokens, axis=0),
                                [batch_size, 1, 1
                                 ]),  # [N, hp.num_gst, 256/hp.num_heads]   
                        num_heads=hp.num_heads,
                        num_units=128,
                        attention_type=hp.style_att_type)

                    style_attention_neg = MultiheadAttention(
                        tf.tanh(tf.expand_dims(refnet_outputs_neg,
                                               axis=1)),  # [N, 1, 128]
                        tf.tile(tf.expand_dims(gst_tokens, axis=0),
                                [batch_size, 1, 1
                                 ]),  # [N, hp.num_gst, 256/hp.num_heads]   
                        num_heads=hp.num_heads,
                        num_units=128,
                        attention_type=hp.style_att_type)

                    # Apply tanh to compress both encoder state and style embedding to the same scale.

                    style_embeddings_pos = style_attention_pos.multi_head_attention(
                    )  # [N, 1, 256]
                    style_embeddings_neg = style_attention_neg.multi_head_attention(
                    )  # [N, 1, 256]

                else:
                    style_embeddings_pos = tf.expand_dims(
                        refnet_outputs_pos, axis=1)  # [N, 1, 128]
                    style_embeddings_neg = tf.expand_dims(refnet_outputs_neg,
                                                          axis=1)
            else:
                print("Use random weight for GST.")

            # Add style embedding to every text encoder state
            ## Tile the style embeddings so they match the text sequence shape;
            ## format: _content_style
            style_embeddings_pos = tf.tile(
                style_embeddings_pos,
                [1, shape_list(encoder_outputs)[1], 1])  # [N, T_in, 128]
            style_embeddings_neg = tf.tile(
                style_embeddings_neg,
                [1, shape_list(encoder_outputs)[1], 1])  # [N, T_in, 128]
            ## Permute the four encoder outputs, e.g. pos2pos is positive content with positive style,
            ## pos2neg is positive content with negative style.
            encoder_outputs_pos = tf.concat(
                [encoder_outputs, style_embeddings_pos], axis=-1)
            encoder_outputs_neg = tf.concat(
                [encoder_outputs, style_embeddings_neg], axis=-1)

        # Decoding scope
        with tf.variable_scope('generator', reuse=tf.AUTO_REUSE) as scope:
            # RNN Attention
            attention_cell_pos = AttentionWrapper(
                DecoderPrenetWrapper(GRUCell(256), is_training),
                BahdanauAttention(256,
                                  encoder_outputs_pos,
                                  memory_sequence_length=input_lengths),
                alignment_history=True,
                output_attention=False)  # [N, T_in, 256]

            attention_cell_neg = AttentionWrapper(
                DecoderPrenetWrapper(GRUCell(256), is_training),
                BahdanauAttention(256,
                                  encoder_outputs_neg,
                                  memory_sequence_length=input_lengths),
                alignment_history=True,
                output_attention=False)  # [N, T_in, 256]

            # Concatenate attention context vector and RNN cell output.
            concat_cell_pos = ConcatOutputAndAttentionWrapper(
                attention_cell_pos)
            concat_cell_neg = ConcatOutputAndAttentionWrapper(
                attention_cell_neg)

            # Decoder (layers specified bottom to top):
            decoder_cell_pos = MultiRNNCell(
                [
                    OutputProjectionWrapper(concat_cell_pos, 256),
                    ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1)),
                    ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1))
                ],
                state_is_tuple=True)  # [N, T_in, 256]

            decoder_cell_neg = MultiRNNCell(
                [
                    OutputProjectionWrapper(concat_cell_neg, 256),
                    ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1)),
                    ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1))
                ],
                state_is_tuple=True)  # [N, T_in, 256]

            # Project onto r mel spectrograms (predict r outputs at each RNN step):
            output_cell_pos = OutputProjectionWrapper(
                decoder_cell_pos, hp.num_mels * hp.outputs_per_step)
            decoder_init_state_pos = output_cell_pos.zero_state(
                batch_size=batch_size, dtype=tf.float32)

            output_cell_neg = OutputProjectionWrapper(
                decoder_cell_neg, hp.num_mels * hp.outputs_per_step)
            decoder_init_state_neg = output_cell_neg.zero_state(
                batch_size=batch_size, dtype=tf.float32)

            if is_training or is_teacher_force_generating:
                helper_pos = TacoTrainingHelper(inputs, mel_targets_pos,
                                                hp.num_mels,
                                                hp.outputs_per_step)
                helper_neg = TacoTrainingHelper(inputs, mel_targets_neg,
                                                hp.num_mels,
                                                hp.outputs_per_step)

            else:
                # At synthesis time, use the same test helper for both branches.
                helper_pos = helper_neg = TacoTestHelper(batch_size, hp.num_mels,
                                                         hp.outputs_per_step)

            (decoder_outputs_pos, _
             ), final_decoder_state_pos, _ = tf.contrib.seq2seq.dynamic_decode(
                 BasicDecoder(output_cell_pos, helper_pos,
                              decoder_init_state_pos),
                 maximum_iterations=hp.max_iters)  # [N, T_out/r, M*r]

            (decoder_outputs_neg, _
             ), final_decoder_state_neg, _ = tf.contrib.seq2seq.dynamic_decode(
                 BasicDecoder(output_cell_neg, helper_neg,
                              decoder_init_state_neg),
                 maximum_iterations=hp.max_iters)  # [N, T_out/r, M*r]

            # Reshape outputs to be one output per entry

            mel_outputs_pos = tf.reshape(
                decoder_outputs_pos,
                [batch_size, -1, hp.num_mels])  # [N, T_out, M]
            mel_outputs_neg = tf.reshape(
                decoder_outputs_neg,
                [batch_size, -1, hp.num_mels])  # [N, T_out, M]

            # Add post-processing CBHG:
            post_outputs_pos = post_cbhg(mel_outputs_pos, hp.num_mels,
                                         is_training)  # [N, T_out, 256]
            linear_outputs_pos = tf.layers.dense(post_outputs_pos,
                                                 hp.num_freq)  # [N, T_out, F]

            post_outputs_neg = post_cbhg(mel_outputs_neg, hp.num_mels,
                                         is_training)  # [N, T_out, 256]
            linear_outputs_neg = tf.layers.dense(post_outputs_neg,
                                                 hp.num_freq)  # [N, T_out, F]

            ## Grab alignments from the final decoder state:
            alignments_pos = tf.transpose(
                final_decoder_state_pos[0].alignment_history.stack(),
                [1, 2, 0])
            alignments_neg = tf.transpose(
                final_decoder_state_neg[0].alignment_history.stack(),
                [1, 2, 0])

            # Extract style features for fake sample
            rec_style = style_encoder(mel_outputs_neg,
                                      filters=[32, 32, 64, 64],
                                      kernel_size=(3, 3),
                                      strides=(2, 2),
                                      is_training=False)
            self.rec_style = rec_style

        # Discriminator scope
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE) as scope:
            self.real_logit = discriminator(content_inputs,
                                            reference_mel_pos,
                                            is_training=is_training)
            self.fake_logit_pos = discriminator(content_inputs,
                                                mel_outputs_pos,
                                                is_training=is_training)
            self.fake_logit_neg = discriminator(content_inputs,
                                                mel_outputs_neg,
                                                is_training=is_training)

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.mel_outputs_pos = mel_outputs_pos
        self.mel_outputs_neg = mel_outputs_neg

        self.encoder_outputs = encoder_outputs

        self.style_embeddings_pos = style_embeddings_pos
        self.style_embeddings_neg = style_embeddings_neg

        self.linear_outputs_pos = linear_outputs_pos
        self.linear_outputs_neg = linear_outputs_neg

        self.alignments_pos = alignments_pos
        self.alignments_neg = alignments_neg
        self.mel_targets_pos = mel_targets_pos
        self.mel_targets_neg = mel_targets_neg
        self.linear_targets_pos = linear_targets_pos
        self.linear_targets_neg = linear_targets_neg
        self.reference_mel_pos = reference_mel_pos
        self.reference_mel_neg = reference_mel_neg
        log('Initialized Tacotron model. Dimensions: ')
        log('text embedding:          %d' % embedded_inputs.shape[-1])
Example #13
    def initialize(self,
                   inputs,
                   input_lengths,
                   mel_targets=None,
                   linear_targets=None,
                   reference_mel=None):
        with tf.variable_scope('inference') as scope:
            is_training = linear_targets is not None
            is_teacher_force_generating = mel_targets is not None
            batch_size = tf.shape(inputs)[0]
            hp = self._hparams

            # Embeddings
            embedding_table = tf.get_variable(
                'text_embedding', [len(symbols), 256],
                dtype=tf.float32,
                initializer=tf.truncated_normal_initializer(stddev=0.5))
            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     inputs)  # [N, T_in, 256]

            if hp.use_gst:
                #Global style tokens (GST)
                gst_tokens = tf.get_variable(
                    'style_tokens', [hp.num_gst, 256 // hp.num_heads],
                    dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=0.5))
                self.gst_tokens = gst_tokens

            # Encoder
            prenet_outputs = prenet(embedded_inputs,
                                    is_training)  # [N, T_in, 128]
            encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths,
                                           is_training)  # [N, T_in, 256]

            if is_training:
                reference_mel = mel_targets

            if reference_mel is not None:
                # Reference encoder
                refnet_outputs = reference_encoder(
                    reference_mel,
                    filters=[32, 32, 64, 64, 128, 128],
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    encoder_cell=GRUCell(128),
                    is_training=is_training)  # [N, 128]
                self.refnet_outputs = refnet_outputs

                if hp.use_gst:
                    # Style attention
                    style_attention = MultiheadAttention(
                        tf.tanh(tf.expand_dims(refnet_outputs,
                                               axis=1)),  # [N, 1, 128]
                        tf.tile(tf.expand_dims(gst_tokens, axis=0),
                                [batch_size, 1, 1
                                 ]),  # [N, hp.num_gst, 256/hp.num_heads]   
                        num_heads=hp.num_heads,
                        num_units=128,
                        attention_type=hp.style_att_type)

                    # Apply tanh to compress both encoder state and style embedding to the same scale.
                    style_embeddings = style_attention.multi_head_attention(
                    )  # [N, 1, 256]
                else:
                    style_embeddings = tf.expand_dims(refnet_outputs,
                                                      axis=1)  # [N, 1, 128]
            else:
                print("Use random weight for GST.")
                random_weights = tf.random_uniform([hp.num_heads, hp.num_gst],
                                                   maxval=1.0,
                                                   dtype=tf.float32)
                random_weights = tf.nn.softmax(random_weights,
                                               name="random_weights")
                style_embeddings = tf.matmul(random_weights,
                                             tf.nn.tanh(gst_tokens))
                style_embeddings = tf.reshape(
                    style_embeddings, [1, 1] +
                    [hp.num_heads * gst_tokens.get_shape().as_list()[1]])

            # Add style embedding to every text encoder state
            style_embeddings = tf.tile(
                style_embeddings,
                [1, shape_list(encoder_outputs)[1], 1])  # [N, T_in, 128]
            encoder_outputs = tf.concat([encoder_outputs, style_embeddings],
                                        axis=-1)

            # Attention
            attention_cell = AttentionWrapper(
                DecoderPrenetWrapper(GRUCell(256), is_training),
                BahdanauAttention(256,
                                  encoder_outputs,
                                  memory_sequence_length=input_lengths),
                alignment_history=True,
                output_attention=False)  # [N, T_in, 256]

            # Concatenate attention context vector and RNN cell output.
            concat_cell = ConcatOutputAndAttentionWrapper(attention_cell)

            # Decoder (layers specified bottom to top):
            decoder_cell = MultiRNNCell([
                OutputProjectionWrapper(concat_cell, 256),
                ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1)),
                ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1))
            ],
                                        state_is_tuple=True)  # [N, T_in, 256]

            # Project onto r mel spectrograms (predict r outputs at each RNN step):
            output_cell = OutputProjectionWrapper(
                decoder_cell, hp.num_mels * hp.outputs_per_step)
            decoder_init_state = output_cell.zero_state(batch_size=batch_size,
                                                        dtype=tf.float32)

            if is_training or is_teacher_force_generating:
                helper = TrainingHelper(inputs, mel_targets, hp.num_mels,
                                        hp.outputs_per_step)
            else:
                helper = TestHelper(batch_size, hp.num_mels,
                                    hp.outputs_per_step)

            (decoder_outputs,
             _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                 BasicDecoder(output_cell, helper, decoder_init_state),
                 maximum_iterations=hp.max_iters)  # [N, T_out/r, M*r]

            # Reshape outputs to be one output per entry
            mel_outputs = tf.reshape(
                decoder_outputs,
                [batch_size, -1, hp.num_mels])  # [N, T_out, M]

            # Add post-processing CBHG:
            post_outputs = post_cbhg(mel_outputs, hp.num_mels,
                                     is_training)  # [N, T_out, 256]
            linear_outputs = tf.layers.dense(post_outputs,
                                             hp.num_freq)  # [N, T_out, F]

            # # Grab alignments from the final decoder state:
            # alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0])

            self.inputs = inputs
            self.input_lengths = input_lengths
            self.mel_outputs = mel_outputs
            self.encoder_outputs = encoder_outputs
            self.style_embeddings = style_embeddings
            self.linear_outputs = linear_outputs
            # self.alignments = alignments
            self.mel_targets = mel_targets
            self.linear_targets = linear_targets
            self.reference_mel = reference_mel
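When no reference mel is provided, Examples #10, #13, and #15 fall back to random (or fixed) attention weights over the style tokens. A toy NumPy sketch of that shape flow, with assumed sizes num_heads=4, num_gst=10, and token depth 256 // num_heads = 64:

import numpy as np

num_heads, num_gst, token_depth = 4, 10, 64             # assumed sizes
gst_tokens = np.random.randn(num_gst, token_depth).astype(np.float32)

weights = np.random.rand(num_heads, num_gst).astype(np.float32)
weights = np.exp(weights) / np.exp(weights).sum(axis=-1, keepdims=True)  # softmax over tokens

style = weights @ np.tanh(gst_tokens)                    # [num_heads, token_depth]
style = style.reshape(1, 1, num_heads * token_depth)     # [1, 1, 256], tiled over T_in afterwards
assert style.shape == (1, 1, 256)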
Example #14
    def initialize(self,
                   inputs,
                   input_lengths,
                   mel_targets=None,
                   linear_targets=None,
                   reference_mels=None):
        """
        Initializes the model for inference.

        Sets "mel_outputs", "linear_outputs", and "alignments" fields.

        Args:
            inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of
            steps in the input time series, and values are character IDs
            input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths
            of each sequence in inputs.
            mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number
            of steps in the output time series, M is num_mels, and values are entries in the mel
            spectrogram. Only needed for training.
            linear_targets: float32 Tensor with shape [N, T_out, F] where N is batch_size, T_out is number
            of steps in the output time series, F is num_freq, and values are entries in the linear
            spectrogram. Only needed for training.
            reference_mels: the reference encoder inputs
        """
        with tf.variable_scope('inference') as scope:
            is_training = linear_targets is not None
            batch_size = tf.shape(inputs)[0]
            hp = self._hparams

            # Embeddings for character inputs: [N, T_in]
            embedding_table = tf.get_variable(
                'embedding', [len(symbols), 256],
                dtype=tf.float32,
                initializer=tf.truncated_normal_initializer(stddev=0.5))
            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     inputs)  # [N, T_in, 256]

            # Encoder
            prenet_outputs = prenet(embedded_inputs,
                                    is_training)  # [N, T_in, 128]
            encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths,
                                           is_training)  # [N, T_in, 256]

            # Whether to use Global Style Tokens
            if is_training:
                reference_mels = mel_targets

            if hp.use_gst:
                gst_tokens = tf.get_variable(
                    'style_tokens', [hp.num_tokens, 256 // hp.num_heads],
                    dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=0.5))
                self.gst_tokens = gst_tokens

                # Reference Encoder
                _, reference_encoder_outputs = reference_encoder(
                    inputs=reference_mels,
                    filters=[32, 32, 64, 64, 128, 128],
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    is_training=is_training)  # [N, 128]
                # Style Token Layer Using Multi-Head Attention
                style_attention = MultiHeadAttention(
                    num_heads=hp.num_heads,
                    num_units=128,
                    attention_type=hp.attention_type)
                style_embedding = tf.nn.tanh(
                    style_attention.multi_head_attention(
                        query=tf.expand_dims(reference_encoder_outputs,
                                             axis=1),  # [N, 1, 128]
                        value=tf.tile(tf.expand_dims(gst_tokens, axis=0),
                                      [batch_size, 1, 1
                                       ])  # [N, num_tokens, 256/num_heads]
                    ))  # [N, 1, 128]

                # add style embedding to encoder outputs
                T_in = shape_list(encoder_outputs)[1]
                style_embedding = tf.tile(style_embedding, [1, T_in, 1])
                encoder_outputs = tf.concat([encoder_outputs, style_embedding],
                                            axis=-1)

            # Attention
            attention_cell = AttentionWrapper(
                DecoderPrenetWrapper(GRUCell(256), is_training),
                BahdanauAttention(256, encoder_outputs),
                alignment_history=True,
                output_attention=False)  # [N, T_in, 256]

            # Concatenate attention context vector and RNN cell output into a 512D vector.
            concat_cell = ConcatOutputAndAttentionWrapper(
                attention_cell)  # [N, T_in, 512]

            # Decoder (layers specified bottom to top),
            # fix decoder cell from gru to lstm and add zoneout
            decoder_cell = MultiRNNCell([
                OutputProjectionWrapper(concat_cell, 256),
                ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1,
                                               is_training)),
                ResidualWrapper(ZoneoutWrapper(LSTMCell(256), 0.1,
                                               is_training))
            ],
                                        state_is_tuple=True)  # [N, T_in, 256]

            # Project onto r mel spectrograms (predict r outputs at each RNN step):
            output_cell = OutputProjectionWrapper(
                decoder_cell, hp.num_mels * hp.outputs_per_step)
            decoder_init_state = output_cell.zero_state(batch_size=batch_size,
                                                        dtype=tf.float32)

            if is_training:
                helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels,
                                            hp.outputs_per_step)
            else:
                helper = TacoTestHelper(batch_size, hp.num_mels,
                                        hp.outputs_per_step)

            (decoder_outputs,
             _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                 BasicDecoder(output_cell, helper, decoder_init_state),
                 maximum_iterations=hp.max_iters)  # [N, T_out/r, M*r]

            # Reshape outputs to be one output per entry
            mel_outputs = tf.reshape(
                decoder_outputs,
                [batch_size, -1, hp.num_mels])  # [N, T_out, M]

            # Add post-processing CBHG:
            post_outputs = post_cbhg(mel_outputs, hp.num_mels,
                                     is_training)  # [N, T_out, 256]
            linear_outputs = tf.layers.dense(post_outputs,
                                             hp.num_freq)  # [N, T_out, F]

            # Grab alignments from the final decoder state:
            alignments = tf.transpose(
                final_decoder_state[0].alignment_history.stack(), [1, 2, 0])

            self.inputs = inputs
            self.input_lengths = input_lengths
            self.mel_outputs = mel_outputs
            self.linear_outputs = linear_outputs
            self.alignments = alignments
            self.mel_targets = mel_targets
            self.linear_targets = linear_targets
            log('Initialized Tacotron model. Dimensions: ')
            log('  embedding:               %d' % embedded_inputs.shape[-1])
            log('  prenet out:              %d' % prenet_outputs.shape[-1])
            log('  encoder out:             %d' % encoder_outputs.shape[-1])
            log('  attention out:           %d' % attention_cell.output_size)
            log('  concat attn & out:       %d' % concat_cell.output_size)
            log('  decoder cell out:        %d' % decoder_cell.output_size)
            log('  decoder out (%d frames):  %d' %
                (hp.outputs_per_step, decoder_outputs.shape[-1]))
            log('  decoder out (1 frame):   %d' % mel_outputs.shape[-1])
            log('  postnet out:             %d' % post_outputs.shape[-1])
            log('  linear out:              %d' % linear_outputs.shape[-1])
Example #15
    def initialize(self,
                   inputs,
                   input_lengths,
                   mel_targets=None,
                   linear_targets=None,
                   reference_mel=None,
                   global_step=None,
                   stop_token_targets=None):
        """
        Initializes the model for inference

        sets "mel_outputs" and "alignments" fields.

        Args:
            - inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of
              steps in the input time series, and values are character IDs
            - input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths
            of each sequence in inputs.
            - mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number
            of steps in the output time series, M is num_mels, and values are entries in the mel
            spectrogram. Only needed for training.
        """
        is_training = linear_targets is not None

        with tf.variable_scope('inference') as scope:
            batch_size = tf.shape(inputs)[0]
            hp = self._hparams
            assert hp.tacotron_teacher_forcing_mode in ('constant',
                                                        'scheduled')
            if hp.tacotron_teacher_forcing_mode == 'scheduled' and is_training:
                assert global_step is not None

            # Embeddings ==> [batch_size, sequence_length, embedding_dim]
            embedding_table = tf.get_variable('inputs_embedding',
                                              [len(symbols), hp.embed_depth],
                                              dtype=tf.float32)
            embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)

            if hp.use_gst:
                # Global style tokens (GST)
                gst_tokens = tf.get_variable(
                    'style_tokens',
                    [hp.num_gst, hp.style_embed_depth // hp.num_heads],
                    dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=0.5))
                self.gst_tokens = gst_tokens

            # Encoder Cell ==> [batch_size, encoder_steps, encoder_lstm_units]
            encoder_cell = TacotronEncoderCell(
                EncoderConvolutions(is_training,
                                    hparams=hp,
                                    scope='encoder_convolutions'),
                EncoderRNN(is_training,
                           size=hp.encoder_lstm_units,
                           zoneout=hp.tacotron_zoneout_rate,
                           scope='encoder_LSTM'))

            encoder_outputs = encoder_cell(embedded_inputs, input_lengths)

            # For shape visualization purpose
            enc_conv_output_shape = encoder_cell.conv_output_shape

            # Decoder Parts
            if is_training:
                reference_mel = mel_targets

            if reference_mel is not None:
                # Reference encoder
                refnet_outputs = reference_encoder(
                    reference_mel,
                    filters=hp.reference_filters,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    encoder_cell=GRUCell(hp.reference_depth),
                    is_training=is_training)  # [N, 128]
                self.refnet_outputs = refnet_outputs

                if hp.use_gst:
                    # Style attention
                    style_attention = MultiheadAttention(
                        tf.expand_dims(refnet_outputs, axis=1),  # [N, 1, 128]
                        tf.tanh(
                            tf.tile(tf.expand_dims(gst_tokens, axis=0),
                                    [batch_size, 1, 1])),
                        # [N, hp.num_gst, 256/hp.num_heads]
                        num_heads=hp.num_heads,
                        num_units=hp.style_att_dim,
                        attention_type=hp.style_att_type)

                    style_embeddings = style_attention.multi_head_attention(
                    )  # [N, 1, 256]
                else:
                    style_embeddings = tf.expand_dims(refnet_outputs,
                                                      axis=1)  # [N, 1, 128]
            else:
                print("Use random weight for GST.")
                random_weights = tf.random_uniform([hp.num_heads, hp.num_gst],
                                                   maxval=1.0,
                                                   dtype=tf.float32)
                random_weights = tf.nn.softmax(random_weights,
                                               name="random_weights")
                style_embeddings = tf.matmul(random_weights,
                                             tf.nn.tanh(gst_tokens))
                style_embeddings = tf.reshape(
                    style_embeddings, [1, 1] +
                    [hp.num_heads * gst_tokens.get_shape().as_list()[1]])

            # Add style embedding to every text encoder state
            style_embeddings = tf.tile(
                style_embeddings,
                [1, shape_list(encoder_outputs)[1], 1])  # [N, T_in, 128]
            encoder_outputs = tf.concat([encoder_outputs, style_embeddings],
                                        axis=-1)

            # Attention Mechanism
            attention_mechanism = LocationSensitiveAttention(
                hp.attention_depth,
                encoder_outputs,
                hparams=hp,
                mask_encoder=False,
                memory_sequence_length=input_lengths,
                smoothing=False,
                cumulate_weights=hp.cumulative_weights)
            # Decoder LSTM Cells
            decoder_lstm = DecoderRNN(is_training,
                                      layers=hp.decoder_layers,
                                      size=hp.decoder_lstm_units,
                                      zoneout=hp.tacotron_zoneout_rate,
                                      scope='decoder_lstm')
            # Frames Projection layer
            frame_projection = FrameProjection(hp.num_mels *
                                               hp.outputs_per_step,
                                               scope='linear_transform')
            # <stop_token> projection layer
            stop_projection = StopProjection(is_training,
                                             shape=hp.outputs_per_step,
                                             scope='stop_token_projection')

            # Attention Decoder Prenet
            prenet = Prenet(is_training,
                            layers_sizes=hp.prenet_layers,
                            drop_rate=hp.tacotron_dropout_rate,
                            scope='prenet')

            # Decoder Cell ==> [batch_size, decoder_steps, num_mels * r] (after decoding)
            decoder_cell = TacotronDecoderCell(prenet, attention_mechanism,
                                               decoder_lstm, frame_projection,
                                               stop_projection)

            # Define the helper for our decoder
            if is_training:
                self.helper = TacoTrainingHelper(batch_size, mel_targets,
                                                 stop_token_targets, hp,
                                                 global_step)
            else:
                self.helper = TacoTestHelper(batch_size, hp)

            # initial decoder state
            decoder_init_state = decoder_cell.zero_state(batch_size=batch_size,
                                                         dtype=tf.float32)

            # Only use max iterations at synthesis time
            max_iters = hp.max_iters if not is_training else None

            # Decode
            (frames_prediction, stop_token_prediction,
             _), final_decoder_state, _ = dynamic_decode(
                 CustomDecoder(decoder_cell, self.helper, decoder_init_state),
                 impute_finished=False,
                 maximum_iterations=max_iters,
                 swap_memory=False)

            # Reshape outputs to be one output per entry
            # ==> [batch_size, non_reduced_decoder_steps (decoder_steps * r), num_mels]
            decoder_output = tf.reshape(frames_prediction,
                                        [batch_size, -1, hp.num_mels])
            stop_token_prediction = tf.reshape(stop_token_prediction,
                                               [batch_size, -1])

            # Postnet
            postnet = Postnet(is_training,
                              hparams=hp,
                              scope='postnet_convolutions')

            # Compute residual using post-net ==> [batch_size, decoder_steps * r, postnet_channels]
            residual = postnet(decoder_output)

            # Project residual to same dimension as mel spectrogram
            # ==> [batch_size, decoder_steps * r, num_mels]
            residual_projection = FrameProjection(hp.num_mels,
                                                  scope='postnet_projection')
            projected_residual = residual_projection(residual)

            # Compute the mel spectrogram
            mel_outputs = decoder_output + projected_residual

            # Based on https://github.com/keithito/tacotron/blob/tacotron2-work-in-progress/models/tacotron.py
            # Post-processing Network to map mels to linear spectrograms using same architecture as the encoder
            post_processing_cell = TacotronEncoderCell(
                EncoderConvolutions(is_training,
                                    hparams=hp,
                                    scope='post_processing_convolutions'),
                EncoderRNN(is_training,
                           size=hp.encoder_lstm_units,
                           zoneout=hp.tacotron_zoneout_rate,
                           scope='post_processing_LSTM'))

            expand_outputs = post_processing_cell(mel_outputs)
            linear_outputs = FrameProjection(
                hp.num_freq,
                scope='post_processing_projection')(expand_outputs)

            # Grab alignments from the final decoder state
            alignments = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0])

            self.inputs = inputs
            self.input_lengths = input_lengths
            self.decoder_output = decoder_output
            self.alignments = alignments
            self.stop_token_prediction = stop_token_prediction
            self.stop_token_targets = stop_token_targets
            self.mel_outputs = mel_outputs
            self.linear_outputs = linear_outputs
            self.linear_targets = linear_targets
            self.mel_targets = mel_targets
            self.reference_mel = reference_mel
            log('Initialized Tacotron model. Dimensions (? = dynamic shape): ')
            log('  Train mode:               {}'.format(is_training))
            log('  embedding:                {}'.format(embedded_inputs.shape))
            log('  enc conv out:             {}'.format(enc_conv_output_shape))
            log('  encoder out:              {}'.format(encoder_outputs.shape))
            log('  decoder out:              {}'.format(decoder_output.shape))
            log('  residual out:             {}'.format(residual.shape))
            log('  projected residual out:   {}'.format(
                projected_residual.shape))
            log('  mel out:                  {}'.format(mel_outputs.shape))
            log('  linear out:               {}'.format(linear_outputs.shape))
            log('  <stop_token> out:         {}'.format(
                stop_token_prediction.shape))