def causal_conv_bn_actv(layer_type, name, inputs, filters, kernel_size,
                        activation_fn, strides, padding, regularizer, training,
                        data_format, bn_momentum, bn_epsilon, dilation=1):
  """Defines a single dilated causal convolutional layer with batch norm."""
  block = conv_bn_actv(
      layer_type=layer_type,
      name=name,
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      activation_fn=activation_fn,
      strides=strides,
      padding=padding,
      regularizer=regularizer,
      training=training,
      data_format=data_format,
      bn_momentum=bn_momentum,
      bn_epsilon=bn_epsilon,
      dilation=dilation
  )

  # pad the left side of the time-series with an amount of zeros based on the
  # dilation rate
  block = tf.pad(block, [[0, 0], [dilation * (kernel_size - 1), 0], [0, 0]])

  return block
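# --- Example (illustration only, not library code) --------------------------
# A minimal, self-contained sketch of why the left pad of
# dilation * (kernel_size - 1) zeros keeps the convolution causal: a VALID
# dilated conv shortens the time axis by exactly that amount, and the pad
# restores the original length using only past context. Assumes TF 1.x-style
# graph execution, matching the code above; the helper name is hypothetical.
import numpy as np
import tensorflow as tf

def _causal_length_check(time=16, kernel_size=3, dilation=2):
  inputs = tf.placeholder(tf.float32, shape=[1, time, 4])
  conv = tf.layers.conv1d(inputs, filters=4, kernel_size=kernel_size,
                          padding="VALID", dilation_rate=dilation)
  # same left pad as causal_conv_bn_actv above
  causal = tf.pad(conv, [[0, 0], [dilation * (kernel_size - 1), 0], [0, 0]])
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(causal, {inputs: np.zeros([1, time, 4], np.float32)})
  assert out.shape[1] == time  # time dimension is preserved after the pad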
def _decode(self, input_dict):
  """Decodes representation into data.

  Args:
    input_dict (dict): Python dictionary with inputs to decoder. Must define:
        * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
          or [time, batch_size, dim]
        * src_lengths - decoder input lengths Tensor of shape [batch_size]
        * tgt_inputs - Only during training. labels Tensor of the shape
          [batch_size, time, num_features] or [time, batch_size, num_features]
        * stop_token_inputs - Only during training. labels Tensor of the shape
          [batch_size, time, 1] or [time, batch_size, 1]
        * tgt_lengths - Only during training. labels lengths Tensor of the
          shape [batch_size]

  Returns:
    dict: A python dictionary containing:
        * outputs - array containing:
            * decoder_output - tensor of shape
              [batch_size, time, num_features] or
              [time, batch_size, num_features]. Spectrogram representation
              learned by the decoder rnn
            * spectrogram_prediction - tensor of shape
              [batch_size, time, num_features] or
              [time, batch_size, num_features]. Spectrogram containing the
              residual corrections from the postnet if enabled
            * alignments - tensor of shape [batch_size, time, memory_size] or
              [time, batch_size, memory_size]. The alignments learned by the
              attention layer
            * stop_token_prediction - tensor of shape [batch_size, time, 1] or
              [time, batch_size, 1]. The stop token predictions
            * final_sequence_lengths - tensor of shape [batch_size]
        * stop_token_prediction - tensor of shape [batch_size, time, 1] or
          [time, batch_size, 1]. The stop token predictions for use inside the
          loss function.
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_length']
  if self._mode == "train":
    spec = input_dict['target_tensors'][0] if 'target_tensors' in \
        input_dict else None
    spec_length = input_dict['target_tensors'][2] if 'target_tensors' in \
        input_dict else None

  _batch_size = encoder_outputs.get_shape().as_list()[0]

  training = (self._mode == "train")
  regularizer = self.params.get('regularizer', None)

  if self.params.get('enable_postnet', True):
    if "postnet_conv_layers" not in self.params:
      raise ValueError(
          "postnet_conv_layers must be passed from config file if postnet is "
          "enabled"
      )

  if self._both:
    num_audio_features = self._n_feats["mel"]
    if self._mode == "train":
      spec, _ = tf.split(
          spec,
          [self._n_feats['mel'], self._n_feats['magnitude']],
          axis=2
      )
  else:
    num_audio_features = self._n_feats

  output_projection_layer = tf.layers.Dense(
      name="output_proj", units=num_audio_features, use_bias=True,
  )
  stop_token_projection_layer = tf.layers.Dense(
      name="stop_token_proj", units=1, use_bias=True,
  )

  prenet = None
  if self.params.get('enable_prenet', True):
    prenet = Prenet(
        self.params.get('prenet_units', 256),
        self.params.get('prenet_layers', 2),
        self.params.get("prenet_activation", tf.nn.relu),
        self.params["dtype"]
    )

  cell_params = {}
  cell_params["num_units"] = self.params['decoder_cell_units']
  decoder_cells = [
      single_cell(
          cell_class=self.params['decoder_cell_type'],
          cell_params=cell_params,
          zoneout_prob=self.params.get("zoneout_prob", 0.),
          dp_output_keep_prob=1. - self.params.get("dropout_prob", 0.1),
          training=training,
      ) for _ in range(self.params['decoder_layers'])
  ]

  if self.params['attention_type'] is not None:
    attention_mechanism = self._build_attention(
        encoder_outputs,
        enc_src_lengths,
        self.params.get("attention_bias", False)
    )

    attention_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

    attentive_cell = AttentionWrapper(
        cell=attention_cell,
        attention_mechanism=attention_mechanism,
        alignment_history=True,
        output_attention="both",
    )

    decoder_cell = attentive_cell

  if self.params['attention_type'] is None:
    decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

  if self._mode == "train":
    train_and_not_sampling = True
    helper = TacotronTrainingHelper(
        inputs=spec,
        sequence_length=spec_length,
        prenet=None,
        model_dtype=self.params["dtype"],
        mask_decoder_sequence=self.params.get("mask_decoder_sequence", True)
    )
  elif self._mode == "eval" or self._mode == "infer":
    train_and_not_sampling = False
    inputs = tf.zeros(
        (_batch_size, 1, num_audio_features), dtype=self.params["dtype"]
    )
    helper = TacotronHelper(
        inputs=inputs,
        prenet=None,
        mask_decoder_sequence=self.params.get("mask_decoder_sequence", True)
    )
  else:
    raise ValueError("Unknown mode for decoder: {}".format(self._mode))

  decoder = TacotronDecoder(
      decoder_cell=decoder_cell,
      helper=helper,
      initial_decoder_state=decoder_cell.zero_state(
          _batch_size, self.params["dtype"]
      ),
      attention_type=self.params["attention_type"],
      spec_layer=output_projection_layer,
      stop_token_layer=stop_token_projection_layer,
      prenet=prenet,
      dtype=self.params["dtype"],
      train=train_and_not_sampling
  )

  if self._mode == 'train':
    maximum_iterations = tf.reduce_max(spec_length)
  else:
    maximum_iterations = tf.reduce_max(enc_src_lengths) * 10

  outputs, final_state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
      # outputs, final_state, sequence_lengths, final_inputs = dynamic_decode(
      decoder=decoder,
      impute_finished=False,
      maximum_iterations=maximum_iterations,
      swap_memory=self.params.get("use_swap_memory", False),
      output_time_major=self.params.get("time_major", False),
      parallel_iterations=self.params.get("parallel_iterations", 32)
  )

  decoder_output = outputs.rnn_output
  stop_token_logits = outputs.stop_token_output

  with tf.variable_scope("decoder"):
    # During training (no sampling) the decoder emits the raw rnn output, so
    # the spectrogram and stop-token projections are applied here instead.
    if train_and_not_sampling:
      decoder_spec_output = output_projection_layer(decoder_output)
      stop_token_logits = stop_token_projection_layer(decoder_spec_output)
      decoder_output = decoder_spec_output

  ## Add the post net ##
  if self.params.get('enable_postnet', True):
    dropout_keep_prob = self.params.get('postnet_keep_dropout_prob', 0.5)

    top_layer = decoder_output
    for i, conv_params in enumerate(self.params['postnet_conv_layers']):
      ch_out = conv_params['num_channels']
      kernel_size = conv_params['kernel_size']  # [time, freq]
      strides = conv_params['stride']
      padding = conv_params['padding']
      activation_fn = conv_params['activation_fn']

      if ch_out == -1:
        if self._both:
          ch_out = self._n_feats["mel"]
        else:
          ch_out = self._n_feats

      top_layer = conv_bn_actv(
          layer_type="conv1d",
          name="conv{}".format(i + 1),
          inputs=top_layer,
          filters=ch_out,
          kernel_size=kernel_size,
          activation_fn=activation_fn,
          strides=strides,
          padding=padding,
          regularizer=regularizer,
          training=training,
          data_format=self.params.get('postnet_data_format', 'channels_last'),
          bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
          bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
      )
      top_layer = tf.layers.dropout(
          top_layer, rate=1. - dropout_keep_prob, training=training
      )

  else:
    top_layer = tf.zeros(
        [
            _batch_size,
            maximum_iterations,
            outputs.rnn_output.get_shape()[-1]
        ],
        dtype=self.params["dtype"]
    )

  if regularizer and training:
    vars_to_regularize = []
    vars_to_regularize += attentive_cell.trainable_variables
    vars_to_regularize += attention_mechanism.memory_layer.trainable_variables
    vars_to_regularize += output_projection_layer.trainable_variables
    vars_to_regularize += stop_token_projection_layer.trainable_variables

    for weights in vars_to_regularize:
      if "bias" not in weights.name:
        # print("Added regularizer to {}".format(weights.name))
        if weights.dtype.base_dtype == tf.float16:
          tf.add_to_collection(
              'REGULARIZATION_FUNCTIONS', (weights, regularizer)
          )
        else:
          tf.add_to_collection(
              ops.GraphKeys.REGULARIZATION_LOSSES, regularizer(weights)
          )

    if self.params.get('enable_prenet', True):
      prenet.add_regularization(regularizer)

  if self.params['attention_type'] is not None:
    alignments = tf.transpose(
        final_state.alignment_history.stack(), [1, 0, 2]
    )
  else:
    alignments = tf.zeros([_batch_size, _batch_size, _batch_size])

  spectrogram_prediction = decoder_output + top_layer

  if self._both:
    mag_spec_prediction = spectrogram_prediction
    mag_spec_prediction = conv_bn_actv(
        layer_type="conv1d",
        name="conv_0",
        inputs=mag_spec_prediction,
        filters=256,
        kernel_size=4,
        activation_fn=tf.nn.relu,
        strides=1,
        padding="SAME",
        regularizer=regularizer,
        training=training,
        data_format=self.params.get('postnet_data_format', 'channels_last'),
        bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
        bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
    )
    mag_spec_prediction = conv_bn_actv(
        layer_type="conv1d",
        name="conv_1",
        inputs=mag_spec_prediction,
        filters=512,
        kernel_size=4,
        activation_fn=tf.nn.relu,
        strides=1,
        padding="SAME",
        regularizer=regularizer,
        training=training,
        data_format=self.params.get('postnet_data_format', 'channels_last'),
        bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
        bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
    )
    if self._model.get_data_layer()._exp_mag:
      mag_spec_prediction = tf.exp(mag_spec_prediction)
    mag_spec_prediction = tf.layers.conv1d(
        mag_spec_prediction,
        self._n_feats["magnitude"],
        1,
        name="post_net_proj",
        use_bias=False,
    )
  else:
    mag_spec_prediction = tf.zeros([_batch_size, _batch_size, _batch_size])

  stop_token_prediction = tf.sigmoid(stop_token_logits)
  outputs = [
      decoder_output, spectrogram_prediction, alignments,
      stop_token_prediction, sequence_lengths, mag_spec_prediction
  ]

  return {
      'outputs': outputs,
      'stop_token_prediction': stop_token_logits,
  }
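# --- Example (illustration only, not library code) --------------------------
# A sketch of the decoder-related parameters that _decode() reads from
# self.params, gathered in one place for reference. The keys are taken from
# the code above; the concrete values are made-up examples, not shipped
# defaults, except where the code itself falls back to a default via
# self.params.get(...). Assumes tensorflow is imported as tf.
_example_decoder_params = {
    "attention_type": "location",          # None disables attention entirely
    "decoder_cell_type": tf.nn.rnn_cell.LSTMCell,
    "decoder_cell_units": 1024,
    "decoder_layers": 2,
    "enable_prenet": True,
    "prenet_units": 256,
    "prenet_layers": 2,
    "enable_postnet": True,
    "postnet_keep_dropout_prob": 0.5,
    "postnet_conv_layers": [
        # num_channels == -1 means "match the number of output audio features"
        {"kernel_size": [5], "stride": [1], "num_channels": 512,
         "padding": "SAME", "activation_fn": tf.nn.tanh},
        {"kernel_size": [5], "stride": [1], "num_channels": -1,
         "padding": "SAME", "activation_fn": None},
    ],
    "zoneout_prob": 0.,
    "dropout_prob": 0.1,
    "mask_decoder_sequence": True,
    "parallel_iterations": 32,
}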
def _encode(self, input_dict):
  """Creates TensorFlow graph for Tacotron-2 like encoder.

  Args:
    input_dict (dict): dictionary with inputs.
      Must define:
        source_tensors - array containing [
          * source_sequence: tensor of shape [batch_size, sequence length]
          * src_length: tensor of shape [batch_size]
        ]

  Returns:
    dict: A python dictionary containing:
      * outputs - tensor containing the encoded text to be passed to the
        attention layer
      * src_length - the length of the encoded text
  """
  text = input_dict['source_tensors'][0]
  text_len = input_dict['source_tensors'][1]

  training = (self._mode == "train")
  regularizer = self.params.get('regularizer', None)
  data_format = self.params.get('data_format', 'channels_last')
  src_vocab_size = self._model.get_data_layer().params['src_vocab_size']
  zoneout_prob = self.params.get('zoneout_prob', 0.)

  # if src_vocab_size % 8 != 0:
  #   src_vocab_size += 8 - (src_vocab_size % 8)

  # ----- Embedding layer -----------------------------------------------
  enc_emb_w = tf.get_variable(
      name="EncoderEmbeddingMatrix",
      shape=[src_vocab_size, self.params['src_emb_size']],
      dtype=self.params['dtype'],
      # initializer=tf.random_normal_initializer()
  )

  embedded_inputs = tf.cast(
      tf.nn.embedding_lookup(enc_emb_w, text),
      self.params['dtype']
  )

  # ----- Convolutional layers -----------------------------------------------
  input_layer = embedded_inputs

  if data_format == 'channels_last':
    top_layer = input_layer
  else:
    top_layer = tf.transpose(input_layer, [0, 2, 1])

  for i, conv_params in enumerate(self.params['conv_layers']):
    ch_out = conv_params['num_channels']
    kernel_size = conv_params['kernel_size']  # [time, freq]
    strides = conv_params['stride']
    padding = conv_params['padding']

    if padding == "VALID":
      text_len = (text_len - kernel_size[0] + strides[0]) // strides[0]
    else:
      text_len = (text_len + strides[0] - 1) // strides[0]

    top_layer = conv_bn_actv(
        layer_type="conv1d",
        name="conv{}".format(i + 1),
        inputs=top_layer,
        filters=ch_out,
        kernel_size=kernel_size,
        activation_fn=self.params['activation_fn'],
        strides=strides,
        padding=padding,
        regularizer=regularizer,
        training=training,
        data_format=data_format,
        bn_momentum=self.params.get('bn_momentum', 0.1),
        bn_epsilon=self.params.get('bn_epsilon', 1e-5),
    )
    top_layer = tf.layers.dropout(
        top_layer, rate=self.params["cnn_dropout_prob"], training=training
    )

  if data_format == 'channels_first':
    top_layer = tf.transpose(top_layer, [0, 2, 1])

  # ----- RNN ---------------------------------------------------------------
  num_rnn_layers = self.params['num_rnn_layers']
  if num_rnn_layers > 0:
    cell_params = {}
    cell_params["num_units"] = self.params['rnn_cell_dim']
    rnn_type = self.params['rnn_type']
    rnn_input = top_layer
    rnn_vars = []

    if self.params["use_cudnn_rnn"]:
      if self._mode == "infer":
        cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(
            cell_params["num_units"]
        )
        cells_fw = [cell() for _ in range(1)]
        cells_bw = [cell() for _ in range(1)]
        (top_layer, _, _) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw, cells_bw, rnn_input,
            sequence_length=text_len,
            dtype=rnn_input.dtype,
            time_major=False
        )
      else:
        all_cudnn_classes = [
            i[1] for i in inspect.getmembers(
                tf.contrib.cudnn_rnn, inspect.isclass
            )
        ]
        if not rnn_type in all_cudnn_classes:
          raise TypeError("rnn_type must be a Cudnn RNN class")
        if zoneout_prob != 0.:
          raise ValueError(
              "Zoneout is currently not supported for cudnn rnn classes"
          )

        rnn_input = tf.transpose(top_layer, [1, 0, 2])
        if self.params['rnn_unidirectional']:
          direction = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
        else:
          direction = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION

        rnn_block = rnn_type(
            num_layers=num_rnn_layers,
            num_units=cell_params["num_units"],
            direction=direction,
            dtype=rnn_input.dtype,
            name="cudnn_rnn"
        )
        rnn_block.build(rnn_input.get_shape())
        top_layer, _ = rnn_block(rnn_input)
        top_layer = tf.transpose(top_layer, [1, 0, 2])
        rnn_vars += rnn_block.trainable_variables
    else:
      multirnn_cell_fw = tf.nn.rnn_cell.MultiRNNCell([
          single_cell(
              cell_class=rnn_type,
              cell_params=cell_params,
              zoneout_prob=zoneout_prob,
              training=training,
              residual_connections=False
          ) for _ in range(num_rnn_layers)
      ])
      rnn_vars += multirnn_cell_fw.trainable_variables
      if self.params['rnn_unidirectional']:
        top_layer, _ = tf.nn.dynamic_rnn(
            cell=multirnn_cell_fw,
            inputs=rnn_input,
            sequence_length=text_len,
            dtype=rnn_input.dtype,
            time_major=False,
        )
      else:
        multirnn_cell_bw = tf.nn.rnn_cell.MultiRNNCell([
            single_cell(
                cell_class=rnn_type,
                cell_params=cell_params,
                zoneout_prob=zoneout_prob,
                training=training,
                residual_connections=False
            ) for _ in range(num_rnn_layers)
        ])
        top_layer, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=multirnn_cell_fw,
            cell_bw=multirnn_cell_bw,
            inputs=rnn_input,
            sequence_length=text_len,
            dtype=rnn_input.dtype,
            time_major=False
        )
        # concat 2 tensors [B, T, n_cell_dim] --> [B, T, 2*n_cell_dim]
        top_layer = tf.concat(top_layer, 2)
        rnn_vars += multirnn_cell_bw.trainable_variables

    if regularizer and training:
      cell_weights = []
      cell_weights += rnn_vars
      cell_weights += [enc_emb_w]

      for weights in cell_weights:
        if "bias" not in weights.name:
          # print("Added regularizer to {}".format(weights.name))
          if weights.dtype.base_dtype == tf.float16:
            tf.add_to_collection(
                'REGULARIZATION_FUNCTIONS', (weights, regularizer)
            )
          else:
            tf.add_to_collection(
                ops.GraphKeys.REGULARIZATION_LOSSES, regularizer(weights)
            )

  # -- end of rnn ------------------------------------------------------------

  top_layer = tf.layers.dropout(
      top_layer, rate=self.params["rnn_dropout_prob"], training=training
  )
  outputs = top_layer

  return {
      'outputs': outputs,
      'src_length': text_len
  }
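# --- Example (illustration only, not library code) --------------------------
# A plain-integer version of the sequence-length bookkeeping done in the conv
# loop above: with "VALID" padding the length shrinks by kernel_size - 1
# before the stride division, with "SAME" it is only divided by the stride
# (rounded up). The helper name is hypothetical.
def _conv_output_length(length, kernel_size, stride, padding):
  if padding == "VALID":
    return (length - kernel_size + stride) // stride
  return (length + stride - 1) // stride

assert _conv_output_length(100, 5, 1, "SAME") == 100
assert _conv_output_length(100, 5, 1, "VALID") == 96
assert _conv_output_length(100, 5, 2, "SAME") == 50
assert _conv_output_length(100, 5, 2, "VALID") == 48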
def _encode(self, input_dict):
  """Creates TensorFlow graph for DeepSpeech-2 like encoder.

  Args:
    input_dict (dict): input dictionary that has to contain
        the following fields::
          input_dict = {
            "source_tensors": [
              src_sequence (shape=[batch_size, sequence length, num features]),
              src_length (shape=[batch_size])
            ]
          }

  Returns:
    dict: dictionary with the following tensors::

      {
        'outputs': hidden state, shape=[batch_size, sequence length, n_hidden]
        'src_length': tensor, shape=[batch_size]
      }
  """
  source_sequence, src_length = input_dict['source_tensors']

  training = (self._mode == "train")
  dropout_keep_prob = self.params['dropout_keep_prob'] if training else 1.0
  regularizer = self.params.get('regularizer', None)
  data_format = self.params.get('data_format', 'channels_last')
  bn_momentum = self.params.get('bn_momentum', 0.99)
  bn_epsilon = self.params.get('bn_epsilon', 1e-3)

  input_layer = tf.expand_dims(source_sequence, axis=-1)  # BTFC
  # print("<<< input :", input_layer.get_shape().as_list())

  batch_size = input_layer.get_shape().as_list()[0]
  freq = input_layer.get_shape().as_list()[2]

  # supported data_formats:
  #    BTFC = channels_last (legacy)
  #    BCTF = channels_first (legacy)
  #    BFTC
  #    BCFT
  if data_format == 'channels_last' or data_format == 'BTFC':
    layout = 'BTFC'
    dformat = 'channels_last'
  elif data_format == 'channels_first' or data_format == 'BCTF':
    layout = 'BCTF'
    dformat = 'channels_first'
  elif data_format == 'BFTC':
    layout = 'BFTC'
    dformat = 'channels_last'
  elif data_format == 'BCFT':
    layout = 'BCFT'
    dformat = 'channels_first'
  else:
    print(
        "WARNING: unsupported data format: will use channels_last (BTFC) instead"
    )
    layout = 'BTFC'
    dformat = 'channels_last'

  # input_layer is BTFC
  if layout == 'BCTF':
    top_layer = tf.transpose(input_layer, [0, 3, 1, 2])
  elif layout == 'BFTC':
    top_layer = tf.transpose(input_layer, [0, 2, 1, 3])
  elif layout == 'BCFT':
    top_layer = tf.transpose(input_layer, [0, 3, 2, 1])
  else:
    top_layer = input_layer
  # print("<<< pre-conv:", top_layer.get_shape().as_list())

  # ----- Convolutional layers ---------------------------------------------
  conv_layers = self.params['conv_layers']

  for idx_conv in range(len(conv_layers)):
    ch_out = conv_layers[idx_conv]['num_channels']
    kernel_size = conv_layers[idx_conv]['kernel_size']  # [T, F] format
    strides = conv_layers[idx_conv]['stride']  # [T, F] format
    padding = conv_layers[idx_conv]['padding']

    if padding == "VALID":
      src_length = (src_length - kernel_size[0] + strides[0]) // strides[0]
      freq = (freq - kernel_size[1] + strides[1]) // strides[1]
    else:
      src_length = (src_length + strides[0] - 1) // strides[0]
      freq = (freq + strides[1] - 1) // strides[1]

    if layout == 'BFTC' or layout == 'BCFT':
      kernel_size = kernel_size[::-1]
      strides = strides[::-1]
      # print(kernel_size, strides)

    top_layer = conv_bn_actv(
        layer_type="conv2d",
        name="conv{}".format(idx_conv + 1),
        inputs=top_layer,
        filters=ch_out,
        kernel_size=kernel_size,
        activation_fn=self.params['activation_fn'],
        strides=strides,
        padding=padding,
        regularizer=regularizer,
        training=training,
        data_format=dformat,
        bn_momentum=bn_momentum,
        bn_epsilon=bn_epsilon,
    )
    # print(idx_conv, "++++", top_layer.get_shape().as_list())

  # convert layout --> BTFC
  # if data_format == 'channels_first':
  #   top_layer = tf.transpose(top_layer, [0, 2, 3, 1])
  if layout == 'BCTF':  # BCTF --> BTFC
    top_layer = tf.transpose(top_layer, [0, 2, 3, 1])
  elif layout == 'BFTC':  # BFTC --> BTFC
    top_layer = tf.transpose(top_layer, [0, 2, 1, 3])
  elif layout == 'BCFT':  # BCFT --> BTFC
    top_layer = tf.transpose(top_layer, [0, 3, 2, 1])
  # print(">>> post-conv:", top_layer.get_shape().as_list())

  # reshape to [B, T, FxC]
  f = top_layer.get_shape().as_list()[2]
  c = top_layer.get_shape().as_list()[3]
  fc = f * c
  top_layer = tf.reshape(top_layer, [batch_size, -1, fc])

  # ----- RNN ---------------------------------------------------------------
  num_rnn_layers = self.params['num_rnn_layers']
  if num_rnn_layers > 0:
    rnn_cell_dim = self.params['rnn_cell_dim']
    rnn_type = self.params['rnn_type']
    if self.params['use_cudnn_rnn']:
      # reshape [B, T, C] --> [T, B, C]
      rnn_input = tf.transpose(top_layer, [1, 0, 2])
      if self.params['rnn_unidirectional']:
        direction = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
      else:
        direction = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION

      if rnn_type == "cudnn_gru" or rnn_type == "gru":
        # pylint: disable=no-member
        rnn_block = tf.contrib.cudnn_rnn.CudnnGRU(
            num_layers=num_rnn_layers,
            num_units=rnn_cell_dim,
            direction=direction,
            dropout=1.0 - dropout_keep_prob,
            dtype=rnn_input.dtype,
            name="cudnn_gru",
        )
      elif rnn_type == "cudnn_lstm" or rnn_type == "lstm":
        # pylint: disable=no-member
        rnn_block = tf.contrib.cudnn_rnn.CudnnLSTM(
            num_layers=num_rnn_layers,
            num_units=rnn_cell_dim,
            direction=direction,
            dropout=1.0 - dropout_keep_prob,
            dtype=rnn_input.dtype,
            name="cudnn_lstm",
        )
      else:
        raise ValueError(
            "{} is not a valid rnn_type for cudnn_rnn layers".format(rnn_type)
        )
      top_layer, state = rnn_block(rnn_input)
      top_layer = tf.transpose(top_layer, [1, 0, 2])
    else:
      rnn_input = top_layer
      multirnn_cell_fw = tf.nn.rnn_cell.MultiRNNCell([
          rnn_cell(
              rnn_cell_dim=rnn_cell_dim,
              layer_type=rnn_type,
              dropout_keep_prob=dropout_keep_prob
          ) for _ in range(num_rnn_layers)
      ])
      if self.params['rnn_unidirectional']:
        top_layer, state = tf.nn.dynamic_rnn(
            cell=multirnn_cell_fw,
            inputs=rnn_input,
            sequence_length=src_length,
            dtype=rnn_input.dtype,
            time_major=False,
        )
      else:
        multirnn_cell_bw = tf.nn.rnn_cell.MultiRNNCell([
            rnn_cell(
                rnn_cell_dim=rnn_cell_dim,
                layer_type=rnn_type,
                dropout_keep_prob=dropout_keep_prob
            ) for _ in range(num_rnn_layers)
        ])
        top_layer, state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=multirnn_cell_fw,
            cell_bw=multirnn_cell_bw,
            inputs=rnn_input,
            sequence_length=src_length,
            dtype=rnn_input.dtype,
            time_major=False
        )
        # concat 2 tensors [B, T, n_cell_dim] --> [B, T, 2*n_cell_dim]
        top_layer = tf.concat(top_layer, 2)

  # -- end of rnn ------------------------------------------------------------

  if self.params['row_conv']:
    channels = top_layer.get_shape().as_list()[-1]
    top_layer = row_conv(
        name="row_conv",
        input_layer=top_layer,
        batch=batch_size,
        channels=channels,
        activation_fn=self.params['activation_fn'],
        width=self.params['row_conv_width'],
        regularizer=regularizer,
        training=training,
        data_format=data_format,
        bn_momentum=bn_momentum,
        bn_epsilon=bn_epsilon,
    )

  # Reshape [B, T, C] --> [B*T, C]
  c = top_layer.get_shape().as_list()[-1]
  top_layer = tf.reshape(top_layer, [-1, c])

  # --- hidden layer with clipped ReLU activation and dropout ----------------
  top_layer = tf.layers.dense(
      inputs=top_layer,
      units=self.params['n_hidden'],
      kernel_regularizer=regularizer,
      activation=self.params['activation_fn'],
      name='fully_connected',
  )
  outputs = tf.nn.dropout(x=top_layer, keep_prob=dropout_keep_prob)

  # reshape from [B*T, A] --> [B, T, A].
  # Output shape: [batch_size, n_steps, n_hidden]
  outputs = tf.reshape(
      outputs,
      [batch_size, -1, self.params['n_hidden']],
  )

  return {
      'outputs': outputs,
      'src_length': src_length,
  }
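# --- Example (illustration only, not library code) --------------------------
# A small numpy sketch of the reshaping the encoder performs around the RNN:
# the [B, T, F, C] conv output is flattened to [B, T, F*C] for the RNN, then
# to [B*T, C'] for the fully connected layer, and finally back to
# [B, T, n_hidden]. The dimensions are made-up examples.
import numpy as np

B, T, F, C, n_hidden = 2, 50, 20, 32, 1600
conv_out = np.zeros([B, T, F, C], dtype=np.float32)  # BTFC conv output
rnn_in = conv_out.reshape([B, -1, F * C])            # [B, T, F*C]
fc_in = rnn_in.reshape([-1, rnn_in.shape[-1]])       # [B*T, F*C]
fc_out = np.zeros([fc_in.shape[0], n_hidden])        # dense layer output
outputs = fc_out.reshape([B, -1, n_hidden])          # [B, T, n_hidden]
assert outputs.shape == (B, T, n_hidden)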
def _embed_style(self, style_spec, style_len):
  """
  Code that implements the reference encoder as described in "Towards
  end-to-end prosody transfer for expressive speech synthesis with Tacotron",
  and "Style Tokens: Unsupervised Style Modeling, Control and Transfer in
  End-to-End Speech Synthesis"

  Config parameters:

  * **conv_layers** (list) --- See the conv_layers parameter for the
    Tacotron-2 model.
  * **num_rnn_layers** (int) --- Number of rnn layers in the reference encoder
  * **rnn_cell_dim** (int) --- Size of rnn layer
  * **rnn_unidirectional** (bool) --- Uni- or bi-directional rnn.
  * **rnn_type** --- Must be a valid tf rnn cell class
  * **emb_size** (int) --- Size of gst
  * **attention_layer_size** (int) --- Size of linear layers in attention
  * **num_tokens** (int) --- Number of tokens for gst
  * **num_heads** (int) --- Number of attention heads
  """
  training = (self._mode == "train")
  regularizer = self.params.get('regularizer', None)
  data_format = self.params.get('data_format', 'channels_last')

  batch_size = style_spec.get_shape().as_list()[0]

  top_layer = tf.expand_dims(style_spec, -1)
  params = self.params['style_embedding_params']
  if "conv_layers" in params:
    for i, conv_params in enumerate(params['conv_layers']):
      ch_out = conv_params['num_channels']
      kernel_size = conv_params['kernel_size']  # [time, freq]
      strides = conv_params['stride']
      padding = conv_params['padding']

      if padding == "VALID":
        style_len = (style_len - kernel_size[0] + strides[0]) // strides[0]
      else:
        style_len = (style_len + strides[0] - 1) // strides[0]

      top_layer = conv_bn_actv(
          layer_type="conv2d",
          name="conv{}".format(i + 1),
          inputs=top_layer,
          filters=ch_out,
          kernel_size=kernel_size,
          activation_fn=self.params['activation_fn'],
          strides=strides,
          padding=padding,
          regularizer=regularizer,
          training=training,
          data_format=data_format,
          bn_momentum=self.params.get('bn_momentum', 0.1),
          bn_epsilon=self.params.get('bn_epsilon', 1e-5),
      )

    if data_format == 'channels_first':
      top_layer = tf.transpose(top_layer, [0, 2, 1])

    top_layer = tf.concat(tf.unstack(top_layer, axis=2), axis=-1)

  num_rnn_layers = params['num_rnn_layers']
  if num_rnn_layers > 0:
    cell_params = {}
    cell_params["num_units"] = params['rnn_cell_dim']
    rnn_type = params['rnn_type']
    rnn_input = top_layer
    rnn_vars = []

    multirnn_cell_fw = tf.nn.rnn_cell.MultiRNNCell([
        single_cell(
            cell_class=rnn_type,
            cell_params=cell_params,
            training=training,
            residual_connections=False
        ) for _ in range(num_rnn_layers)
    ])
    rnn_vars += multirnn_cell_fw.trainable_variables
    if params['rnn_unidirectional']:
      top_layer, final_state = tf.nn.dynamic_rnn(
          cell=multirnn_cell_fw,
          inputs=rnn_input,
          sequence_length=style_len,
          dtype=rnn_input.dtype,
          time_major=False,
      )
      final_state = final_state[0]
    else:
      multirnn_cell_bw = tf.nn.rnn_cell.MultiRNNCell([
          single_cell(
              cell_class=rnn_type,
              cell_params=cell_params,
              training=training,
              residual_connections=False
          ) for _ in range(num_rnn_layers)
      ])
      top_layer, final_state = tf.nn.bidirectional_dynamic_rnn(
          cell_fw=multirnn_cell_fw,
          cell_bw=multirnn_cell_bw,
          inputs=rnn_input,
          sequence_length=style_len,
          dtype=rnn_input.dtype,
          time_major=False
      )
      # concat the final forward and backward hidden states
      final_state = tf.concat(
          (final_state[0][0].h, final_state[1][0].h), 1
      )
      rnn_vars += multirnn_cell_bw.trainable_variables

    top_layer = final_state

    # Apply linear layer
    top_layer = tf.layers.dense(
        top_layer,
        128,
        activation=tf.nn.tanh,
        kernel_regularizer=regularizer,
        name="reference_activation"
    )

    if regularizer and training:
      cell_weights = rnn_vars
      for weights in cell_weights:
        if "bias" not in weights.name:
          # print("Added regularizer to {}".format(weights.name))
          if weights.dtype.base_dtype == tf.float16:
            tf.add_to_collection(
                'REGULARIZATION_FUNCTIONS', (weights, regularizer)
            )
          else:
            tf.add_to_collection(
                ops.GraphKeys.REGULARIZATION_LOSSES, regularizer(weights)
            )

  num_units = params["num_tokens"]
  att_size = params["attention_layer_size"]

  # Randomly initialized tokens
  gst_embedding = tf.get_variable(
      "token_embeddings",
      shape=[num_units, params["emb_size"]],
      dtype=self.params["dtype"],
      initializer=tf.random_uniform_initializer(
          minval=-1.,
          maxval=1.,
          dtype=self.params["dtype"]
      ),
      trainable=False
  )

  attention = attention_layer.Attention(
      params["attention_layer_size"],
      params["num_heads"],
      0.,
      training,
      mode="bahdanau"
  )
  top_layer = tf.expand_dims(top_layer, 1)
  gst_embedding = tf.nn.tanh(gst_embedding)
  gst_embedding = tf.expand_dims(gst_embedding, 0)
  gst_embedding = tf.tile(gst_embedding, [batch_size, 1, 1])
  token_embeddings = attention(top_layer, gst_embedding, None)
  token_embeddings = tf.squeeze(token_embeddings, 1)

  return token_embeddings