def _decode(self, input_dict):
  """
  Decodes representation into data

  Args:
    input_dict (dict): Python dictionary with inputs to decoder. Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
        or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
      * tgt_inputs - Only during training. Labels Tensor of shape
        [batch_size, time, num_features] or [time, batch_size, num_features]
      * stop_token_inputs - Only during training. Labels Tensor of shape
        [batch_size, time, 1] or [time, batch_size, 1]
      * tgt_lengths - Only during training. Labels lengths Tensor of shape
        [batch_size]

  Returns:
    dict: A python dictionary containing:
      * outputs - array containing:
        * decoder_output - tensor of shape [batch_size, time, num_features]
          or [time, batch_size, num_features]. Spectrogram representation
          learned by the decoder rnn
        * spectrogram_prediction - tensor of shape
          [batch_size, time, num_features] or
          [time, batch_size, num_features]. Spectrogram containing the
          residual corrections from the postnet if enabled
        * alignments - tensor of shape [batch_size, time, memory_size] or
          [time, batch_size, memory_size]. The alignments learned by the
          attention layer
        * stop_token_prediction - tensor of shape [batch_size, time, 1] or
          [time, batch_size, 1]. The stop token predictions
        * final_sequence_lengths - tensor of shape [batch_size]
        * mag_spec_prediction - magnitude spectrogram prediction (zeros
          unless both mel and magnitude features are used)
      * stop_token_prediction - tensor of shape [batch_size, time, 1] or
        [time, batch_size, 1]. The raw stop token logits, used inside the
        loss function.
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_length']
  if self._mode == "train" or \
      (self._mode == "infer" and self._gta_forcing == True):
    spec = input_dict['target_tensors'][0]
    spec_length = input_dict['target_tensors'][2]
  else:
    spec = None
    spec_length = None

  _batch_size = encoder_outputs.get_shape().as_list()[0]
  training = (self._mode == "train")
  regularizer = self.params.get('regularizer', None)

  if self.params.get('enable_postnet', True):
    if "postnet_conv_layers" not in self.params:
      raise ValueError(
          "postnet_conv_layers must be passed from config file if postnet is "
          "enabled")

  if self._both:
    num_audio_features = self._n_feats["mel"]
    if self._mode == "train":
      spec, _ = tf.split(
          spec,
          [self._n_feats['mel'], self._n_feats['magnitude']],
          axis=2)
  else:
    num_audio_features = self._n_feats

  output_projection_layer = tf.layers.Dense(
      name="output_proj", units=num_audio_features, use_bias=True)
  stop_token_projection_layer = tf.layers.Dense(
      name="stop_token_proj", units=1, use_bias=True)

  prenet = None
  if self.params.get('enable_prenet', True):
    prenet = Prenet(
        self.params.get('prenet_units', 256),
        self.params.get('prenet_layers', 2),
        self.params.get("prenet_activation", tf.nn.relu),
        self.params["dtype"])

  cell_params = {}
  cell_params["num_units"] = self.params['decoder_cell_units']
  decoder_cells = [
      single_cell(
          cell_class=self.params['decoder_cell_type'],
          cell_params=cell_params,
          zoneout_prob=self.params.get("zoneout_prob", 0.),
          dp_output_keep_prob=1. - self.params.get("dropout_prob", 0.1),
          training=training,
      ) for _ in range(self.params['decoder_layers'])
  ]

  if self.params['attention_type'] is not None:
    attention_mechanism = self._build_attention(
        encoder_outputs, enc_src_lengths,
        self.params.get("attention_bias", False))
    attention_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)
    attentive_cell = AttentionWrapper(
        cell=attention_cell,
        attention_mechanism=attention_mechanism,
        alignment_history=True,
        output_attention="both",
    )
    decoder_cell = attentive_cell

  if self.params['attention_type'] is None:
    decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

  if self._mode == "train":
    train_and_not_sampling = True
    helper = TacotronTrainingHelper(
        inputs=spec,
        sequence_length=spec_length,
        prenet=None,
        model_dtype=self.params["dtype"],
        mask_decoder_sequence=self.params.get("mask_decoder_sequence", True))
  elif self._mode == "eval" or self._mode == "infer":
    train_and_not_sampling = False
    inputs = tf.zeros(
        (_batch_size, 1, num_audio_features), dtype=self.params["dtype"])
    helper = TacotronHelper(
        inputs=inputs,
        prenet=None,
        mask_decoder_sequence=self.params.get("mask_decoder_sequence", True),
        gta_mels=spec,
        gta_mel_lengths=spec_length,
    )
  else:
    raise ValueError("Unknown mode for decoder: {}".format(self._mode))

  decoder = TacotronDecoder(
      decoder_cell=decoder_cell,
      helper=helper,
      initial_decoder_state=decoder_cell.zero_state(
          _batch_size, self.params["dtype"]),
      attention_type=self.params["attention_type"],
      spec_layer=output_projection_layer,
      stop_token_layer=stop_token_projection_layer,
      prenet=prenet,
      dtype=self.params["dtype"],
      train=train_and_not_sampling)

  if self._mode == 'train':
    maximum_iterations = tf.reduce_max(spec_length)
  else:
    maximum_iterations = tf.reduce_max(enc_src_lengths) * 10

  # outputs, final_state, sequence_lengths, final_inputs = dynamic_decode(
  outputs, final_state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
      decoder=decoder,
      impute_finished=False,
      maximum_iterations=maximum_iterations,
      swap_memory=self.params.get("use_swap_memory", False),
      output_time_major=self.params.get("time_major", False),
      parallel_iterations=self.params.get("parallel_iterations", 32))

  decoder_output = outputs.rnn_output
  stop_token_logits = outputs.stop_token_output

  with tf.variable_scope("decoder"):
    # If we are in train and doing sampling, we need to do the projections
    if train_and_not_sampling:
      decoder_spec_output = output_projection_layer(decoder_output)
      stop_token_logits = stop_token_projection_layer(decoder_spec_output)
      decoder_output = decoder_spec_output

  ## Add the post net ##
  if self.params.get('enable_postnet', True):
    dropout_keep_prob = self.params.get('postnet_keep_dropout_prob', 0.5)

    top_layer = decoder_output
    for i, conv_params in enumerate(self.params['postnet_conv_layers']):
      ch_out = conv_params['num_channels']
      kernel_size = conv_params['kernel_size']  # [time, freq]
      strides = conv_params['stride']
      padding = conv_params['padding']
      activation_fn = conv_params['activation_fn']

      if ch_out == -1:
        if self._both:
          ch_out = self._n_feats["mel"]
        else:
          ch_out = self._n_feats

      top_layer = conv_bn_actv(
          layer_type="conv1d",
          name="conv{}".format(i + 1),
          inputs=top_layer,
          filters=ch_out,
          kernel_size=kernel_size,
          activation_fn=activation_fn,
          strides=strides,
          padding=padding,
          regularizer=regularizer,
          training=training,
          data_format=self.params.get('postnet_data_format', 'channels_last'),
          bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
          bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
      )
      top_layer = tf.layers.dropout(
          top_layer, rate=1. - dropout_keep_prob, training=training)
  else:
    top_layer = tf.zeros(
        [_batch_size, maximum_iterations, outputs.rnn_output.get_shape()[-1]],
        dtype=self.params["dtype"])

  if regularizer and training:
    vars_to_regularize = []
    vars_to_regularize += attentive_cell.trainable_variables
    if attention_mechanism.memory_layer is not None:
      vars_to_regularize += attention_mechanism.memory_layer.trainable_variables
    vars_to_regularize += output_projection_layer.trainable_variables
    vars_to_regularize += stop_token_projection_layer.trainable_variables

    for weights in vars_to_regularize:
      if "bias" not in weights.name:
        # print("Added regularizer to {}".format(weights.name))
        if weights.dtype.base_dtype == tf.float16:
          tf.add_to_collection(
              'REGULARIZATION_FUNCTIONS', (weights, regularizer))
        else:
          tf.add_to_collection(
              ops.GraphKeys.REGULARIZATION_LOSSES, regularizer(weights))

    if self.params.get('enable_prenet', True):
      prenet.add_regularization(regularizer)

  if self.params['attention_type'] is not None:
    alignments = tf.transpose(
        final_state.alignment_history.stack(), [1, 0, 2])
  else:
    alignments = tf.zeros([_batch_size, _batch_size, _batch_size])

  spectrogram_prediction = decoder_output + top_layer

  if self._both:
    mag_spec_prediction = spectrogram_prediction
    mag_spec_prediction = conv_bn_actv(
        layer_type="conv1d",
        name="conv_0",
        inputs=mag_spec_prediction,
        filters=256,
        kernel_size=4,
        activation_fn=tf.nn.relu,
        strides=1,
        padding="SAME",
        regularizer=regularizer,
        training=training,
        data_format=self.params.get('postnet_data_format', 'channels_last'),
        bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
        bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
    )
    mag_spec_prediction = conv_bn_actv(
        layer_type="conv1d",
        name="conv_1",
        inputs=mag_spec_prediction,
        filters=512,
        kernel_size=4,
        activation_fn=tf.nn.relu,
        strides=1,
        padding="SAME",
        regularizer=regularizer,
        training=training,
        data_format=self.params.get('postnet_data_format', 'channels_last'),
        bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
        bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
    )
    if self._model.get_data_layer()._exp_mag:
      mag_spec_prediction = tf.exp(mag_spec_prediction)
    mag_spec_prediction = tf.layers.conv1d(
        mag_spec_prediction,
        self._n_feats["magnitude"],
        1,
        name="post_net_proj",
        use_bias=False,
    )
  else:
    mag_spec_prediction = tf.zeros([_batch_size, _batch_size, _batch_size])

  stop_token_prediction = tf.sigmoid(stop_token_logits)
  outputs = [
      decoder_output, spectrogram_prediction, alignments,
      stop_token_prediction, sequence_lengths, mag_spec_prediction
  ]

  return {
      'outputs': outputs,
      'stop_token_prediction': stop_token_logits,
  }
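
# Illustrative sketch (not taken from any reference config): a possible value
# for the "postnet_conv_layers" parameter consumed by the postnet loop in
# _decode above. The dictionary keys are exactly the ones read there
# (num_channels, kernel_size, stride, padding, activation_fn); the concrete
# values below are assumptions chosen only to show the expected structure.
_EXAMPLE_POSTNET_CONV_LAYERS = [
    {
        "num_channels": 512,
        "kernel_size": [5],
        "stride": [1],
        "padding": "SAME",
        "activation_fn": tf.nn.tanh,
    },
    {
        # num_channels == -1 makes _decode project back to num_audio_features
        "num_channels": -1,
        "kernel_size": [5],
        "stride": [1],
        "padding": "SAME",
        "activation_fn": None,
    },
]
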

def _decode(self, input_dict):
  """
  Decodes representation into data

  :param input_dict: Python dictionary with inputs to decoder
  Must define:
    * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                   or [time, batch_size, dim]
    * src_lengths - decoder input lengths Tensor of shape [batch_size]
  Does not need tgt_inputs and tgt_lengths

  :return: a Python dictionary with:
    * logits - ids of the best beam search hypothesis,
               tensor of shape [batch_size, time]
    * samples - list containing the same ids tensor
    * final_state - tensor with decoder final state
    * final_sequence_lengths - tensor with lengths of the decoded sequences
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_lengths']

  self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32)

  self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
  )

  cell_params = copy.deepcopy(self.params)
  cell_params["num_units"] = self.params['decoder_cell_units']

  if self._mode == "train":
    dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
    dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
  else:
    dp_input_keep_prob = 1.0
    dp_output_keep_prob = 1.0

  if self.params['attention_type'].startswith('gnmt'):
    residual_connections = False
    wrap_to_multi_rnn = False
  else:
    residual_connections = self.params['decoder_use_skip_connections']
    wrap_to_multi_rnn = True

  self._decoder_cells = create_rnn_cell(
      cell_type=self.params['decoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['decoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=residual_connections,
      wrap_to_multi_rnn=wrap_to_multi_rnn,
  )

  tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs,
      multiplier=self._beam_width,
  )
  tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
      enc_src_lengths,
      multiplier=self._beam_width,
  )
  attention_mechanism = self._build_attention(
      tiled_enc_outputs,
      tiled_enc_src_lengths,
  )

  if self.params['attention_type'].startswith('gnmt'):
    attention_cell = self._decoder_cells.pop(0)
    attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,  # don't use attention layer
        output_attention=False,
        name="gnmt_attention")
    attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell,
        self._add_residual_wrapper(self._decoder_cells),
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
  else:
    attentive_decoder_cell = AttentionWrapper(
        cell=self._decoder_cells,
        attention_mechanism=attention_mechanism,
    )

  batch_size_tensor = tf.constant(self._batch_size)
  embedding_fn = lambda ids: tf.cast(
      tf.nn.embedding_lookup(self._dec_emb_w, ids),
      dtype=self.params['dtype'])

  # decoder = tf.contrib.seq2seq.BeamSearchDecoder(
  decoder = BeamSearchDecoder(
      cell=attentive_decoder_cell,
      embedding=embedding_fn,
      start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
      end_token=self.END_SYMBOL,
      initial_state=attentive_decoder_cell.zero_state(
          dtype=encoder_outputs.dtype,
          batch_size=batch_size_tensor * self._beam_width,
      ),
      beam_width=self._beam_width,
      output_layer=self._output_projection_layer,
      length_penalty_weight=self._length_penalty_weight)

  time_major = self.params.get("time_major", False)
  use_swap_memory = self.params.get("use_swap_memory", False)

  final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
          decoder=decoder,
          maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
          swap_memory=use_swap_memory,
          output_time_major=time_major,
      )

  return {
      'logits': final_outputs.predicted_ids[:, :, 0],
      'samples': [final_outputs.predicted_ids[:, :, 0]],
      'final_state': final_state,
      'final_sequence_lengths': final_sequence_lengths
  }
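
# Illustrative sketch (assumption, not part of the source): what tile_batch
# does to the encoder tensors before beam search. Every batch row is repeated
# beam_width times along the batch axis, so the decoder scores beam_width
# hypotheses per source sequence in one batched pass; predicted_ids then has
# shape [batch_size, time, beam_width], and the [:, :, 0] slice above keeps
# only the highest-scoring beam.
def _tile_batch_example():
  enc = tf.constant([[1.0], [2.0]])  # batch_size = 2
  # rows become [1, 1, 1, 2, 2, 2] -> shape [6, 1] for beam_width = 3
  return tf.contrib.seq2seq.tile_batch(enc, multiplier=3)
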

def _decode(self, input_dict):
  """
  Decodes representation into data

  :param input_dict: Python dictionary with inputs to decoder
  Must define:
    * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                   or [time, batch_size, dim]
    * src_lengths - decoder input lengths Tensor of shape [batch_size]
    * tgt_inputs - Only during training. Labels Tensor of shape
                   [batch_size, time] or [time, batch_size]
    * tgt_lengths - Only during training. Labels lengths Tensor of shape
                    [batch_size]

  :return: a Python dictionary with:
    * logits - decoder output logits, tensor of shape
               [batch_size, time, tgt_vocab_size] or
               [time, batch_size, tgt_vocab_size]
    * samples - list containing the argmax over the logits,
                tensor of shape [batch_size, time]
    * final_state - tensor with decoder final state
    * final_sequence_lengths - tensor of shape [batch_size]
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_lengths']
  tgt_inputs = input_dict['target_tensors'][0] \
      if 'target_tensors' in input_dict else None
  tgt_lengths = input_dict['target_tensors'][1] \
      if 'target_tensors' in input_dict else None

  self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32,
  )

  self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
  )

  cell_params = copy.deepcopy(self.params)
  cell_params["num_units"] = self.params['decoder_cell_units']

  if self._mode == "train":
    dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
    dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
  else:
    dp_input_keep_prob = 1.0
    dp_output_keep_prob = 1.0

  if self.params['attention_type'].startswith('gnmt'):
    residual_connections = False
    wrap_to_multi_rnn = False
  else:
    residual_connections = self.params['decoder_use_skip_connections']
    wrap_to_multi_rnn = True

  self._decoder_cells = create_rnn_cell(
      cell_type=self.params['decoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['decoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=residual_connections,
      wrap_to_multi_rnn=wrap_to_multi_rnn,
  )

  attention_mechanism = self._build_attention(
      encoder_outputs,
      enc_src_lengths,
  )

  if self.params['attention_type'].startswith('gnmt'):
    attention_cell = self._decoder_cells.pop(0)
    # attention_cell = tf.contrib.seq2seq.AttentionWrapper(
    attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,
        output_attention=False,
        name="gnmt_attention")
    attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell,
        self._add_residual_wrapper(self._decoder_cells),
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
  else:
    # attentive_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    attentive_decoder_cell = AttentionWrapper(
        cell=self._decoder_cells,
        attention_mechanism=attention_mechanism,
    )

  if self._mode == "train":
    input_vectors = tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
        dtype=self.params['dtype'])
    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=input_vectors,
        sequence_length=tgt_lengths)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attentive_decoder_cell,
        helper=helper,
        output_layer=self._output_projection_layer,
        initial_state=attentive_decoder_cell.zero_state(
            self._batch_size, dtype=encoder_outputs.dtype,
        ),
    )
  elif self._mode == "infer" or self._mode == "eval":
    embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, ids),
        dtype=self.params['dtype'])
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embedding_fn,  # self._dec_emb_w,
        start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
        end_token=self.END_SYMBOL)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attentive_decoder_cell,
        helper=helper,
        initial_state=attentive_decoder_cell.zero_state(
            batch_size=self._batch_size,
            dtype=encoder_outputs.dtype,
        ),
        output_layer=self._output_projection_layer,
    )
  else:
    raise ValueError("Unknown mode for decoder: {}".format(self._mode))

  time_major = self.params.get("time_major", False)
  use_swap_memory = self.params.get("use_swap_memory", False)

  if self._mode == 'train':
    maximum_iterations = tf.reduce_max(tgt_lengths)
  else:
    maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

  final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
          decoder=decoder,
          # impute_finished=False if self._decoder_type == "beam_search" else True,
          impute_finished=True,
          maximum_iterations=maximum_iterations,
          swap_memory=use_swap_memory,
          output_time_major=time_major,
      )

  return {
      'logits': final_outputs.rnn_output,
      'samples': [tf.argmax(final_outputs.rnn_output, axis=-1)],
      'final_state': final_state,
      'final_sequence_lengths': final_sequence_lengths
  }
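
# Illustrative sketch (assumption): the minimal input_dict layout this _decode
# reads. The keys mirror the lookups above; the shapes and dtypes below are
# placeholder guesses, not values taken from any real data layer.
_example_input_dict = {
    'encoder_output': {
        'outputs': tf.placeholder(tf.float32, [8, None, 512]),  # [batch, time, dim]
        'src_lengths': tf.placeholder(tf.int32, [8]),           # [batch]
    },
    # present only in train/eval mode:
    'target_tensors': [
        tf.placeholder(tf.int32, [8, None]),  # tgt_inputs, [batch, time]
        tf.placeholder(tf.int32, [8]),        # tgt_lengths, [batch]
    ],
}
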

def _decode(self, input_dict):
  """Decodes representation into data.

  Args:
    input_dict (dict): Python dictionary with inputs to decoder.

  Config parameters:
    * **src_inputs** --- Decoder input Tensor of shape
      [batch_size, time, dim] or [time, batch_size, dim].
    * **src_lengths** --- Decoder input lengths Tensor of shape [batch_size]
    * **tgt_inputs** --- Only during training. labels Tensor of the shape
      [batch_size, time] or [time, batch_size].
    * **tgt_lengths** --- Only during training. labels lengths Tensor of the
      shape [batch_size].

  Returns:
    dict: Python dictionary with:
      * outputs - [predictions, alignments, enc_src_lengths].
        predictions are the final predictions of the model, tensor of shape
        [batch_size, time]. alignments are the attention probabilities if
        attention is used, None if 'plot_attention' in attention_params is
        set to False. enc_src_lengths are the lengths of the input, tensor
        of shape [batch_size].
      * logits - logits of shape [batch_size, time, tgt_vocab_size] (or the
        predicted ids of shape [batch_size, time] when beam search decoding
        is used).
      * tgt_length - tensor of shape [batch_size] indicating the predicted
        sequence lengths.
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_length']

  self._batch_size = int(encoder_outputs.get_shape()[0])
  self._beam_width = self.params.get("beam_width", 1)

  tgt_inputs = None
  tgt_lengths = None
  if 'target_tensors' in input_dict:
    tgt_inputs = input_dict['target_tensors'][0]
    tgt_lengths = input_dict['target_tensors'][1]
    tgt_inputs = tf.concat(
        [tf.fill([self._batch_size, 1], self.GO_SYMBOL), tgt_inputs[:, :-1]],
        -1)

  layer_type = self.params['rnn_type']
  num_layers = self.params['num_layers']
  attention_params = self.params['attention_params']
  hidden_dim = self.params['hidden_dim']
  dropout_keep_prob = self.params.get(
      'dropout_keep_prob', 1.0) if self._mode == "train" else 1.0

  # TODO: Separate encoder and decoder position embeddings
  use_positional_embedding = self.params.get("pos_embedding", False)
  use_language_model = self.params.get("use_language_model", False)
  use_beam_search_decoder = (
      self._beam_width != 1) and (self._mode == "infer")

  self._target_emb_layer = tf.get_variable(
      name='TargetEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32,
  )

  if use_positional_embedding:
    self.enc_pos_emb_size = int(encoder_outputs.get_shape()[-1])
    self.enc_pos_emb_layer = tf.get_variable(
        name='EncoderPositionEmbeddingMatrix',
        shape=[1024, self.enc_pos_emb_size],
        dtype=tf.float32,
    )
    encoder_output_positions = tf.range(
        0,
        tf.shape(encoder_outputs)[1],
        delta=1,
        dtype=tf.int32,
        name='positional_inputs')
    encoder_position_embeddings = tf.cast(
        tf.nn.embedding_lookup(
            self.enc_pos_emb_layer, encoder_output_positions),
        dtype=encoder_outputs.dtype)
    encoder_outputs += encoder_position_embeddings

    self.dec_pos_emb_size = self._tgt_emb_size
    self.dec_pos_emb_layer = tf.get_variable(
        name='DecoderPositionEmbeddingMatrix',
        shape=[1024, self.dec_pos_emb_size],
        dtype=tf.float32,
    )

  output_projection_layer = FullyConnected(
      [self._tgt_vocab_size],
      dropout_keep_prob=dropout_keep_prob,
      mode=self._mode,
  )

  rnn_cell = cells_dict[layer_type]
  dropout = tf.nn.rnn_cell.DropoutWrapper

  multirnn_cell = tf.nn.rnn_cell.MultiRNNCell(
      [dropout(rnn_cell(hidden_dim), output_keep_prob=dropout_keep_prob)
       for _ in range(num_layers)]
  )

  if use_beam_search_decoder:
    encoder_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs,
        multiplier=self._beam_width,
    )
    enc_src_lengths = tf.contrib.seq2seq.tile_batch(
        enc_src_lengths,
        multiplier=self._beam_width,
    )

  attention_dim = attention_params["attention_dim"]
  attention_type = attention_params["attention_type"]
  num_heads = attention_params["num_heads"]
  plot_attention = attention_params["plot_attention"]

  if plot_attention:
    if use_beam_search_decoder:
      plot_attention = False
      print("Plotting Attention is disabled for Beam Search Decoding")
    if num_heads != 1:
      plot_attention = False
      print("Plotting Attention is disabled for Multi Head Attention")
    if self.params['dtype'] != tf.float32:
      plot_attention = False
      print("Plotting Attention is disabled for Mixed Precision Mode")

  attention_params_dict = {}
  if attention_type == "bahadanu":
    AttentionMechanism = BahdanauAttention
    attention_params_dict["normalize"] = False
  elif attention_type == "chorowski":
    AttentionMechanism = LocationSensitiveAttention
    attention_params_dict["use_coverage"] = attention_params["use_coverage"]
    attention_params_dict["location_attn_type"] = attention_type
    attention_params_dict["location_attention_params"] = {
        'filters': 10, 'kernel_size': 101}
  elif attention_type == "zhaopeng":
    AttentionMechanism = LocationSensitiveAttention
    attention_params_dict["use_coverage"] = attention_params["use_coverage"]
    attention_params_dict["query_dim"] = hidden_dim
    attention_params_dict["location_attn_type"] = attention_type

  attention_mechanism = []
  for head in range(num_heads):
    attention_mechanism.append(
        AttentionMechanism(
            num_units=attention_dim,
            memory=encoder_outputs,
            memory_sequence_length=enc_src_lengths,
            probability_fn=tf.nn.softmax,
            dtype=tf.get_variable_scope().dtype,
            **attention_params_dict))

  multirnn_cell_with_attention = AttentionWrapper(
      cell=multirnn_cell,
      attention_mechanism=attention_mechanism,
      attention_layer_size=[hidden_dim for i in range(num_heads)],
      output_attention=True,
      alignment_history=plot_attention,
  )

  if self._mode == "train":
    decoder_output_positions = tf.range(
        0,
        tf.shape(tgt_inputs)[1],
        delta=1,
        dtype=tf.int32,
        name='positional_inputs')
    tgt_input_vectors = tf.nn.embedding_lookup(
        self._target_emb_layer, tgt_inputs)
    if use_positional_embedding:
      tgt_input_vectors += tf.nn.embedding_lookup(
          self.dec_pos_emb_layer, decoder_output_positions)
    tgt_input_vectors = tf.cast(
        tgt_input_vectors,
        dtype=self.params['dtype'],
    )
    # helper = tf.contrib.seq2seq.TrainingHelper(
    helper = TrainingHelper(
        inputs=tgt_input_vectors,
        sequence_length=tgt_lengths,
    )
  elif self._mode == "infer" or self._mode == "eval":
    embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._target_emb_layer, ids),
        dtype=self.params['dtype'],
    )
    pos_embedding_fn = None
    if use_positional_embedding:
      pos_embedding_fn = lambda ids: tf.cast(
          tf.nn.embedding_lookup(self.dec_pos_emb_layer, ids),
          dtype=self.params['dtype'],
      )
    # helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    helper = GreedyEmbeddingHelper(
        embedding=embedding_fn,
        start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
        end_token=self.END_SYMBOL,
        positional_embedding=pos_embedding_fn)

  if self._mode != "infer":
    maximum_iterations = tf.reduce_max(tgt_lengths)
  else:
    maximum_iterations = tf.reduce_max(enc_src_lengths)

  if not use_beam_search_decoder:
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=multirnn_cell_with_attention,
        helper=helper,
        initial_state=multirnn_cell_with_attention.zero_state(
            batch_size=self._batch_size,
            dtype=encoder_outputs.dtype,
        ),
        output_layer=output_projection_layer,
    )
  else:
    batch_size_tensor = tf.constant(self._batch_size)
    decoder = BeamSearchDecoder(
        cell=multirnn_cell_with_attention,
        embedding=embedding_fn,
        start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
        end_token=self.END_SYMBOL,
        initial_state=multirnn_cell_with_attention.zero_state(
            dtype=encoder_outputs.dtype,
            batch_size=batch_size_tensor * self._beam_width,
        ),
        beam_width=self._beam_width,
        output_layer=output_projection_layer,
        length_penalty_weight=0.0,
    )

  final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
          decoder=decoder,
          impute_finished=self.mode != "infer",
          maximum_iterations=maximum_iterations,
      )

  if plot_attention:
    alignments = tf.transpose(
        final_state.alignment_history[0].stack(), [1, 0, 2])
  else:
    alignments = None

  if not use_beam_search_decoder:
    outputs = tf.argmax(final_outputs.rnn_output, axis=-1)
    logits = final_outputs.rnn_output
    return_outputs = [outputs, alignments, enc_src_lengths]
  else:
    outputs = final_outputs.predicted_ids[:, :, 0]
    logits = final_outputs.predicted_ids[:, :, 0]
    return_outputs = [outputs, enc_src_lengths]

  if self.mode == "eval":
    max_len = tf.reduce_max(tgt_lengths)
    logits = tf.while_loop(
        lambda logits: max_len > tf.shape(logits)[1],
        lambda logits: tf.concat(
            [logits,
             tf.fill([tf.shape(logits)[0], 1, tf.shape(logits)[2]],
                     tf.cast(1.0, self.params['dtype']))],
            1),
        loop_vars=[logits],
        back_prop=False,
    )

  return {
      'outputs': return_outputs,
      'logits': logits,
      'tgt_length': final_sequence_lengths,
  }
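
# Illustrative sketch (assumption): a possible "attention_params" dictionary
# for this decoder. The keys are exactly the ones read in _decode above; the
# values shown are guesses meant only to illustrate the expected types, not a
# reference configuration.
_example_attention_params = {
    "attention_dim": 256,
    "attention_type": "chorowski",  # one of the types handled above
    "num_heads": 1,
    "plot_attention": True,
    "use_coverage": False,  # read for the location-sensitive attention types
}
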

def _decode(self, input_dict):
  """Decodes representation into data.

  Args:
    input_dict (dict): Python dictionary with inputs to decoder
    Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
        or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
    Does not need tgt_inputs and tgt_lengths

  Returns:
    dict: a Python dictionary with:
      * logits - ids of the best beam search hypothesis, tensor of shape
        [batch_size, time] or [time, batch_size]
      * outputs - list containing the same ids tensor
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor with lengths of the decoded sequences
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_lengths']

  self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32
  )

  self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
  )

  if self._mode == "train":
    dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
    dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
  else:
    dp_input_keep_prob = 1.0
    dp_output_keep_prob = 1.0

  residual_connections = self.params['decoder_use_skip_connections']

  # list of cells
  self._decoder_cells = [
      single_cell(
          cell_class=self.params['core_cell'],
          cell_params=self.params.get('core_cell_params', {}),
          dp_input_keep_prob=dp_input_keep_prob,
          dp_output_keep_prob=dp_output_keep_prob,
          # residual connections are added a little differently for GNMT
          residual_connections=False
          if self.params['attention_type'].startswith('gnmt')
          else residual_connections,
      ) for _ in range(self.params['decoder_layers'])
  ]

  # pylint: disable=no-member
  tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs,
      multiplier=self._beam_width,
  )
  # pylint: disable=no-member
  tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
      enc_src_lengths,
      multiplier=self._beam_width,
  )
  attention_mechanism = self._build_attention(
      tiled_enc_outputs,
      tiled_enc_src_lengths,
  )

  if self.params['attention_type'].startswith('gnmt'):
    attention_cell = self._decoder_cells.pop(0)
    attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,  # don't use attention layer
        output_attention=False,
        name="gnmt_attention",
    )
    attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell,
        self._add_residual_wrapper(self._decoder_cells)
        if residual_connections else self._decoder_cells,
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2')
    )
  else:  # non-GNMT
    attentive_decoder_cell = AttentionWrapper(
        # pylint: disable=no-member
        cell=tf.contrib.rnn.MultiRNNCell(self._decoder_cells),
        attention_mechanism=attention_mechanism,
    )

  batch_size_tensor = tf.constant(self._batch_size)
  embedding_fn = lambda ids: tf.cast(
      tf.nn.embedding_lookup(self._dec_emb_w, ids),
      dtype=self.params['dtype'],
  )

  decoder = BeamSearchDecoder(
      cell=attentive_decoder_cell,
      embedding=embedding_fn,
      start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
      end_token=self.END_SYMBOL,
      initial_state=attentive_decoder_cell.zero_state(
          dtype=encoder_outputs.dtype,
          batch_size=batch_size_tensor * self._beam_width,
      ),
      beam_width=self._beam_width,
      output_layer=self._output_projection_layer,
      length_penalty_weight=self._length_penalty_weight
  )

  time_major = self.params.get("time_major", False)
  use_swap_memory = self.params.get("use_swap_memory", False)

  final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(  # pylint: disable=no-member
          decoder=decoder,
          maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
          swap_memory=use_swap_memory,
          output_time_major=time_major,
      )

  return {
      'logits': final_outputs.predicted_ids[:, :, 0] if not time_major else
                tf.transpose(final_outputs.predicted_ids[:, :, 0],
                             perm=[1, 0]),
      'outputs': [final_outputs.predicted_ids[:, :, 0]],
      'final_state': final_state,
      'final_sequence_lengths': final_sequence_lengths,
  }

def _decode(self, input_dict):
  """Decodes representation into data.

  Args:
    input_dict (dict): Python dictionary with inputs to decoder.

  Config parameters:
    * **src_inputs** --- Decoder input Tensor of shape
      [batch_size, time, dim] or [time, batch_size, dim]
    * **src_lengths** --- Decoder input lengths Tensor of shape [batch_size]
    * **tgt_inputs** --- Only during training. labels Tensor of the shape
      [batch_size, time] or [time, batch_size].
    * **tgt_lengths** --- Only during training. labels lengths Tensor of the
      shape [batch_size].

  Returns:
    dict: Python dictionary with:
      * logits - decoder output logits, tensor of shape
        [batch_size, time, tgt_vocab_size] or
        [time, batch_size, tgt_vocab_size]
      * outputs - list containing the argmax over the logits, tensor of
        shape [batch_size, time]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size]
  """
  encoder_outputs = input_dict['encoder_output']['outputs']
  enc_src_lengths = input_dict['encoder_output']['src_lengths']
  tgt_inputs = input_dict['target_tensors'][0] \
      if 'target_tensors' in input_dict else None
  tgt_lengths = input_dict['target_tensors'][1] \
      if 'target_tensors' in input_dict else None

  self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32,
  )

  self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
  )

  if self._mode == "train":
    dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
    dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
  else:
    dp_input_keep_prob = 1.0
    dp_output_keep_prob = 1.0

  residual_connections = self.params['decoder_use_skip_connections']

  # list of cells
  self._decoder_cells = [
      single_cell(
          cell_class=self.params['core_cell'],
          cell_params=self.params.get('core_cell_params', {}),
          dp_input_keep_prob=dp_input_keep_prob,
          dp_output_keep_prob=dp_output_keep_prob,
          # residual connections are added a little differently for GNMT
          residual_connections=False
          if self.params['attention_type'].startswith('gnmt')
          else residual_connections,
      ) for _ in range(self.params['decoder_layers'])
  ]

  attention_mechanism = self._build_attention(
      encoder_outputs,
      enc_src_lengths,
  )

  if self.params['attention_type'].startswith('gnmt'):
    attention_cell = self._decoder_cells.pop(0)
    attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,
        output_attention=False,
        name="gnmt_attention",
    )
    attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell,
        self._add_residual_wrapper(self._decoder_cells)
        if residual_connections else self._decoder_cells,
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'),
    )
  else:
    attentive_decoder_cell = AttentionWrapper(
        # pylint: disable=no-member
        cell=tf.contrib.rnn.MultiRNNCell(self._decoder_cells),
        attention_mechanism=attention_mechanism,
    )

  if self._mode == "train":
    input_vectors = tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
        dtype=self.params['dtype'],
    )
    helper = tf.contrib.seq2seq.TrainingHelper(  # pylint: disable=no-member
        inputs=input_vectors,
        sequence_length=tgt_lengths,
    )
    decoder = tf.contrib.seq2seq.BasicDecoder(  # pylint: disable=no-member
        cell=attentive_decoder_cell,
        helper=helper,
        output_layer=self._output_projection_layer,
        initial_state=attentive_decoder_cell.zero_state(
            self._batch_size, dtype=encoder_outputs.dtype,
        ),
    )
  elif self._mode == "infer" or self._mode == "eval":
    embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, ids),
        dtype=self.params['dtype'],
    )
    # pylint: disable=no-member
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embedding_fn,
        start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
        end_token=self.END_SYMBOL,
    )
    decoder = tf.contrib.seq2seq.BasicDecoder(  # pylint: disable=no-member
        cell=attentive_decoder_cell,
        helper=helper,
        initial_state=attentive_decoder_cell.zero_state(
            batch_size=self._batch_size,
            dtype=encoder_outputs.dtype,
        ),
        output_layer=self._output_projection_layer,
    )
  else:
    raise ValueError("Unknown mode for decoder: {}".format(self._mode))

  time_major = self.params.get("time_major", False)
  use_swap_memory = self.params.get("use_swap_memory", False)

  if self._mode == 'train':
    maximum_iterations = tf.reduce_max(tgt_lengths)
  else:
    maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

  # pylint: disable=no-member
  final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
          decoder=decoder,
          impute_finished=True,
          maximum_iterations=maximum_iterations,
          swap_memory=use_swap_memory,
          output_time_major=time_major,
      )

  return {
      'logits': final_outputs.rnn_output if not time_major else
                tf.transpose(final_outputs.rnn_output, perm=[1, 0, 2]),
      'outputs': [tf.argmax(final_outputs.rnn_output, axis=-1)],
      'final_state': final_state,
      'final_sequence_lengths': final_sequence_lengths,
  }
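
# Illustrative sketch (assumption): turning the greedy 'outputs' ids returned
# above back into tokens. Both "idx2token" and "end_id" are hypothetical
# arguments standing in for whatever vocabulary the data layer provides; this
# helper is not part of the decoder itself.
def _ids_to_tokens_example(sample_ids, idx2token, end_id):
  tokens = []
  for idx in sample_ids:  # sample_ids: 1-D sequence of int token ids
    if idx == end_id:     # stop at the end-of-sequence id
      break
    tokens.append(idx2token[idx])
  return tokens
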