Code example #1
    def _decode(self, input_dict):
        """
    Decodes representation into data

    Args:
      input_dict (dict): Python dictionary with inputs to decoder. Must define:
          * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
            or [time, batch_size, dim]
          * src_lengths - decoder input lengths Tensor of shape [batch_size]
          * tgt_inputs - Only during training. labels Tensor of the
            shape [batch_size, time, num_features] or
            [time, batch_size, num_features]
          * stop_token_inputs - Only during training. labels Tensor of the
            shape [batch_size, time, 1] or [time, batch_size, 1]
          * tgt_lengths - Only during training. labels lengths
            Tensor of the shape [batch_size]

    Returns:
      dict:
        A python dictionary containing:

          * outputs - array containing:

              * decoder_output - tensor of shape [batch_size, time,
                num_features] or [time, batch_size, num_features]. Spectrogram
                representation learned by the decoder rnn
              * spectrogram_prediction - tensor of shape [batch_size, time,
                num_features] or [time, batch_size, num_features]. Spectrogram
                containing the residual corrections from the postnet if enabled
              * alignments - tensor of shape [batch_size, time, memory_size]
                or [time, batch_size, memory_size]. The alignments learned by
                the attention layer
              * stop_token_prediction - tensor of shape [batch_size, time, 1]
                or [time, batch_size, 1]. The stop token predictions
              * final_sequence_lengths - tensor of shape [batch_size]
              * mag_spec_prediction - the predicted magnitude spectrogram when
                both mel and magnitude features are used, otherwise a dummy
                tensor of zeros
          * stop_token_prediction - tensor of shape [batch_size, time, 1]
            or [time, batch_size, 1]. The raw stop token logits, for use inside
            the loss function.
    """
        encoder_outputs = input_dict['encoder_output']['outputs']
        enc_src_lengths = input_dict['encoder_output']['src_length']
        if self._mode == "train" or (self._mode == "infer"
                                     and self._gta_forcing == True):
            spec = input_dict['target_tensors'][0]
            spec_length = input_dict['target_tensors'][2]
        else:
            spec = None
            spec_length = None

        _batch_size = encoder_outputs.get_shape().as_list()[0]

        training = (self._mode == "train")
        regularizer = self.params.get('regularizer', None)

        if self.params.get('enable_postnet', True):
            if "postnet_conv_layers" not in self.params:
                raise ValueError(
                    "postnet_conv_layers must be passed from config file "
                    "if postnet is enabled")

        if self._both:
            num_audio_features = self._n_feats["mel"]
            if self._mode == "train":
                spec, _ = tf.split(
                    spec, [self._n_feats['mel'], self._n_feats['magnitude']],
                    axis=2)
        else:
            num_audio_features = self._n_feats

        output_projection_layer = tf.layers.Dense(
            name="output_proj",
            units=num_audio_features,
            use_bias=True,
        )
        stop_token_projection_layer = tf.layers.Dense(
            name="stop_token_proj",
            units=1,
            use_bias=True,
        )

        prenet = None
        if self.params.get('enable_prenet', True):
            prenet = Prenet(self.params.get('prenet_units', 256),
                            self.params.get('prenet_layers', 2),
                            self.params.get("prenet_activation", tf.nn.relu),
                            self.params["dtype"])

        cell_params = {}
        cell_params["num_units"] = self.params['decoder_cell_units']
        decoder_cells = [
            single_cell(
                cell_class=self.params['decoder_cell_type'],
                cell_params=cell_params,
                zoneout_prob=self.params.get("zoneout_prob", 0.),
                dp_output_keep_prob=1. - self.params.get("dropout_prob", 0.1),
                training=training,
            ) for _ in range(self.params['decoder_layers'])
        ]

        if self.params['attention_type'] is not None:
            attention_mechanism = self._build_attention(
                encoder_outputs, enc_src_lengths,
                self.params.get("attention_bias", False))

            attention_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

            attentive_cell = AttentionWrapper(
                cell=attention_cell,
                attention_mechanism=attention_mechanism,
                alignment_history=True,
                output_attention="both",
            )

            decoder_cell = attentive_cell

        if self.params['attention_type'] is None:
            decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)

        if self._mode == "train":
            train_and_not_sampling = True
            helper = TacotronTrainingHelper(
                inputs=spec,
                sequence_length=spec_length,
                prenet=None,
                model_dtype=self.params["dtype"],
                mask_decoder_sequence=self.params.get("mask_decoder_sequence",
                                                      True))
        elif self._mode == "eval" or self._mode == "infer":
            train_and_not_sampling = False
            inputs = tf.zeros((_batch_size, 1, num_audio_features),
                              dtype=self.params["dtype"])
            helper = TacotronHelper(
                inputs=inputs,
                prenet=None,
                mask_decoder_sequence=self.params.get("mask_decoder_sequence",
                                                      True),
                gta_mels=spec,
                gta_mel_lengths=spec_length,
            )
        else:
            raise ValueError("Unknown mode for decoder: {}".format(self._mode))
        decoder = TacotronDecoder(
            decoder_cell=decoder_cell,
            helper=helper,
            initial_decoder_state=decoder_cell.zero_state(
                _batch_size, self.params["dtype"]),
            attention_type=self.params["attention_type"],
            spec_layer=output_projection_layer,
            stop_token_layer=stop_token_projection_layer,
            prenet=prenet,
            dtype=self.params["dtype"],
            train=train_and_not_sampling)

        if self._mode == 'train':
            maximum_iterations = tf.reduce_max(spec_length)
        else:
            maximum_iterations = tf.reduce_max(enc_src_lengths) * 10

        outputs, final_state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            # outputs, final_state, sequence_lengths, final_inputs = dynamic_decode(
            decoder=decoder,
            impute_finished=False,
            maximum_iterations=maximum_iterations,
            swap_memory=self.params.get("use_swap_memory", False),
            output_time_major=self.params.get("time_major", False),
            parallel_iterations=self.params.get("parallel_iterations", 32))

        decoder_output = outputs.rnn_output
        stop_token_logits = outputs.stop_token_output

        with tf.variable_scope("decoder"):
            # In train mode (no sampling inside the decoder), the spectrogram and
            # stop token projections still need to be applied here
            if train_and_not_sampling:
                decoder_spec_output = output_projection_layer(decoder_output)
                stop_token_logits = stop_token_projection_layer(
                    decoder_spec_output)
                decoder_output = decoder_spec_output

        ## Add the post net ##
        if self.params.get('enable_postnet', True):
            dropout_keep_prob = self.params.get('postnet_keep_dropout_prob',
                                                0.5)

            top_layer = decoder_output
            for i, conv_params in enumerate(
                    self.params['postnet_conv_layers']):
                ch_out = conv_params['num_channels']
                kernel_size = conv_params['kernel_size']  # [time, freq]
                strides = conv_params['stride']
                padding = conv_params['padding']
                activation_fn = conv_params['activation_fn']

                if ch_out == -1:
                    if self._both:
                        ch_out = self._n_feats["mel"]
                    else:
                        ch_out = self._n_feats

                top_layer = conv_bn_actv(
                    layer_type="conv1d",
                    name="conv{}".format(i + 1),
                    inputs=top_layer,
                    filters=ch_out,
                    kernel_size=kernel_size,
                    activation_fn=activation_fn,
                    strides=strides,
                    padding=padding,
                    regularizer=regularizer,
                    training=training,
                    data_format=self.params.get('postnet_data_format',
                                                'channels_last'),
                    bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
                    bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
                )
                top_layer = tf.layers.dropout(top_layer,
                                              rate=1. - dropout_keep_prob,
                                              training=training)

        else:
            top_layer = tf.zeros([
                _batch_size, maximum_iterations,
                outputs.rnn_output.get_shape()[-1]
            ],
                                 dtype=self.params["dtype"])

        if regularizer and training:
            vars_to_regularize = []
            # attentive_cell / attention_mechanism only exist when attention is enabled
            if self.params['attention_type'] is not None:
                vars_to_regularize += attentive_cell.trainable_variables
                if attention_mechanism.memory_layer is not None:
                    vars_to_regularize += attention_mechanism.memory_layer.trainable_variables
            vars_to_regularize += output_projection_layer.trainable_variables
            vars_to_regularize += stop_token_projection_layer.trainable_variables

            for weights in vars_to_regularize:
                if "bias" not in weights.name:
                    # print("Added regularizer to {}".format(weights.name))
                    if weights.dtype.base_dtype == tf.float16:
                        tf.add_to_collection('REGULARIZATION_FUNCTIONS',
                                             (weights, regularizer))
                    else:
                        tf.add_to_collection(
                            ops.GraphKeys.REGULARIZATION_LOSSES,
                            regularizer(weights))

            if self.params.get('enable_prenet', True):
                prenet.add_regularization(regularizer)

        if self.params['attention_type'] is not None:
            alignments = tf.transpose(final_state.alignment_history.stack(),
                                      [1, 0, 2])
        else:
            alignments = tf.zeros([_batch_size, _batch_size, _batch_size])

        spectrogram_prediction = decoder_output + top_layer
        if self._both:
            mag_spec_prediction = spectrogram_prediction
            mag_spec_prediction = conv_bn_actv(
                layer_type="conv1d",
                name="conv_0",
                inputs=mag_spec_prediction,
                filters=256,
                kernel_size=4,
                activation_fn=tf.nn.relu,
                strides=1,
                padding="SAME",
                regularizer=regularizer,
                training=training,
                data_format=self.params.get('postnet_data_format',
                                            'channels_last'),
                bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
                bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
            )
            mag_spec_prediction = conv_bn_actv(
                layer_type="conv1d",
                name="conv_1",
                inputs=mag_spec_prediction,
                filters=512,
                kernel_size=4,
                activation_fn=tf.nn.relu,
                strides=1,
                padding="SAME",
                regularizer=regularizer,
                training=training,
                data_format=self.params.get('postnet_data_format',
                                            'channels_last'),
                bn_momentum=self.params.get('postnet_bn_momentum', 0.1),
                bn_epsilon=self.params.get('postnet_bn_epsilon', 1e-5),
            )
            if self._model.get_data_layer()._exp_mag:
                mag_spec_prediction = tf.exp(mag_spec_prediction)
            mag_spec_prediction = tf.layers.conv1d(
                mag_spec_prediction,
                self._n_feats["magnitude"],
                1,
                name="post_net_proj",
                use_bias=False,
            )
        else:
            mag_spec_prediction = tf.zeros(
                [_batch_size, _batch_size, _batch_size])

        stop_token_prediction = tf.sigmoid(stop_token_logits)
        outputs = [
            decoder_output, spectrogram_prediction, alignments,
            stop_token_prediction, sequence_lengths, mag_spec_prediction
        ]

        return {
            'outputs': outputs,
            'stop_token_prediction': stop_token_logits,
        }
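
Note: the decoder in example #1 pulls every hyperparameter out of self.params. Purely as orientation, a config fragment covering the keys that _decode actually reads might look like the sketch below; the key names come from the code above, but the values (and the cell class) are illustrative assumptions, not the project's reference config.

# Illustrative sketch of the params consumed by the Tacotron-style _decode above.
# Key names mirror the self.params lookups in the code; values are assumptions.
import tensorflow as tf

tacotron_decoder_params = {
    "attention_type": "location",          # set to None to disable attention
    "attention_bias": False,
    "decoder_cell_type": tf.nn.rnn_cell.LSTMCell,
    "decoder_cell_units": 1024,
    "decoder_layers": 2,
    "zoneout_prob": 0.,
    "dropout_prob": 0.1,
    "enable_prenet": True,
    "prenet_units": 256,
    "prenet_layers": 2,
    "enable_postnet": True,
    "postnet_conv_layers": [               # required when enable_postnet is True
        {"num_channels": 512, "kernel_size": [5], "stride": [1],
         "padding": "SAME", "activation_fn": tf.nn.tanh},
        {"num_channels": -1, "kernel_size": [5], "stride": [1],
         "padding": "SAME", "activation_fn": None},   # -1 -> number of output features
    ],
    "postnet_keep_dropout_prob": 0.5,
    "mask_decoder_sequence": True,
    "dtype": tf.float32,
    "parallel_iterations": 32,
}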
Code example #2
File: rnn_decoders.py  Project: vsl9/OpenSeq2Seq
    def _decode(self, input_dict):
        """
    Decodes representation into data
    :param input_dict: Python dictionary with inputs to decoder
    Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                     or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
    Does not need tgt_inputs and tgt_lengths
    :return: a Python dictionary with:
      * final_outputs - tensor of shape [batch_size, time, dim] or
                        [time, batch_size, dim]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size, time] or
                                 [time, batch_size]
    """
        encoder_outputs = input_dict['encoder_output']['outputs']
        enc_src_lengths = input_dict['encoder_output']['src_lengths']

        self._dec_emb_w = tf.get_variable(
            name='DecoderEmbeddingMatrix',
            shape=[self._tgt_vocab_size, self._tgt_emb_size],
            dtype=tf.float32)

        self._output_projection_layer = tf.layers.Dense(
            self._tgt_vocab_size,
            use_bias=False,
        )

        cell_params = copy.deepcopy(self.params)
        cell_params["num_units"] = self.params['decoder_cell_units']

        if self._mode == "train":
            dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
            dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
        else:
            dp_input_keep_prob = 1.0
            dp_output_keep_prob = 1.0

        if self.params['attention_type'].startswith('gnmt'):
            residual_connections = False
            wrap_to_multi_rnn = False
        else:
            residual_connections = self.params['decoder_use_skip_connections']
            wrap_to_multi_rnn = True

        self._decoder_cells = create_rnn_cell(
            cell_type=self.params['decoder_cell_type'],
            cell_params=cell_params,
            num_layers=self.params['decoder_layers'],
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            residual_connections=residual_connections,
            wrap_to_multi_rnn=wrap_to_multi_rnn,
        )

        tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
            encoder_outputs,
            multiplier=self._beam_width,
        )
        tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
            enc_src_lengths,
            multiplier=self._beam_width,
        )
        attention_mechanism = self._build_attention(
            tiled_enc_outputs,
            tiled_enc_src_lengths,
        )

        if self.params['attention_type'].startswith('gnmt'):
            attention_cell = self._decoder_cells.pop(0)
            attention_cell = AttentionWrapper(
                attention_cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=None,  # don't use attention layer.
                output_attention=False,
                name="gnmt_attention")
            attentive_decoder_cell = GNMTAttentionMultiCell(
                attention_cell,
                self._add_residual_wrapper(self._decoder_cells),
                use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
        else:
            attentive_decoder_cell = AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanism,
            )
        batch_size_tensor = tf.constant(self._batch_size)
        embedding_fn = lambda ids: tf.cast(tf.nn.embedding_lookup(
            self._dec_emb_w, ids),
                                           dtype=self.params['dtype'])
        #decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        decoder = BeamSearchDecoder(
            cell=attentive_decoder_cell,
            embedding=embedding_fn,
            start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
            end_token=self.END_SYMBOL,
            initial_state=attentive_decoder_cell.zero_state(
                dtype=encoder_outputs.dtype,
                batch_size=batch_size_tensor * self._beam_width,
            ),
            beam_width=self._beam_width,
            output_layer=self._output_projection_layer,
            length_penalty_weight=self._length_penalty_weight)

        time_major = self.params.get("time_major", False)
        use_swap_memory = self.params.get("use_swap_memory", False)
        final_outputs, final_state, final_sequence_lengths = \
          tf.contrib.seq2seq.dynamic_decode(
          decoder=decoder,
          maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
          swap_memory=use_swap_memory,
          output_time_major=time_major,
        )

        return {
            'logits': final_outputs.predicted_ids[:, :, 0],
            'samples': [final_outputs.predicted_ids[:, :, 0]],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths
        }
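
Every _decode on this page receives the same nested input_dict. As a minimal sketch of what the beam-search variant in example #2 expects, and which keys it hands back, consider the fragment below; the placeholder shapes are assumptions for illustration.

# Assumed-shape sketch of the input_dict consumed by the beam-search _decode above.
import tensorflow as tf

batch_size, max_time, enc_dim = 8, 50, 512
input_dict = {
    "encoder_output": {
        "outputs": tf.zeros([batch_size, max_time, enc_dim]),   # encoder states
        "src_lengths": tf.fill([batch_size], max_time),         # true input lengths
    },
    # no 'target_tensors' entry: this variant only runs beam-search inference
}

# result = self._decode(input_dict) would then return:
#   result['logits']  - ids of the best beam, i.e. predicted_ids[:, :, 0]
#   result['samples'] - a one-element list with the same ids
#   result['final_state'], result['final_sequence_lengths'] - passed through
#   from tf.contrib.seq2seq.dynamic_decode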
Code example #3
File: rnn_decoders.py  Project: vsl9/OpenSeq2Seq
    def _decode(self, input_dict):
        """
    Decodes representation into data
    :param input_dict: Python dictionary with inputs to decoder
    Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                     or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
      * tgt_inputs - Only during training. labels Tensor of the
                     shape [batch_size, time] or [time, batch_size]
      * tgt_lengths - Only during training. labels lengths
                      Tensor of the shape [batch_size]
    :return: a Python dictionary with:
      * final_outputs - tensor of shape [batch_size, time, dim]
                        or [time, batch_size, dim]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size, time]
                                 or [time, batch_size]
    """
        encoder_outputs = input_dict['encoder_output']['outputs']
        enc_src_lengths = input_dict['encoder_output']['src_lengths']
        tgt_inputs = input_dict['target_tensors'][0] if 'target_tensors' in \
                                                        input_dict else None
        tgt_lengths = input_dict['target_tensors'][1] if 'target_tensors' in \
                                                        input_dict else None

        self._dec_emb_w = tf.get_variable(
            name='DecoderEmbeddingMatrix',
            shape=[self._tgt_vocab_size, self._tgt_emb_size],
            dtype=tf.float32,
        )

        self._output_projection_layer = tf.layers.Dense(
            self._tgt_vocab_size,
            use_bias=False,
        )

        cell_params = copy.deepcopy(self.params)
        cell_params["num_units"] = self.params['decoder_cell_units']

        if self._mode == "train":
            dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
            dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
        else:
            dp_input_keep_prob = 1.0
            dp_output_keep_prob = 1.0

        if self.params['attention_type'].startswith('gnmt'):
            residual_connections = False
            wrap_to_multi_rnn = False
        else:
            residual_connections = self.params['decoder_use_skip_connections']
            wrap_to_multi_rnn = True

        self._decoder_cells = create_rnn_cell(
            cell_type=self.params['decoder_cell_type'],
            cell_params=cell_params,
            num_layers=self.params['decoder_layers'],
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            residual_connections=residual_connections,
            wrap_to_multi_rnn=wrap_to_multi_rnn,
        )

        attention_mechanism = self._build_attention(
            encoder_outputs,
            enc_src_lengths,
        )
        if self.params['attention_type'].startswith('gnmt'):
            attention_cell = self._decoder_cells.pop(0)
            # attention_cell = tf.contrib.seq2seq.AttentionWrapper(
            attention_cell = AttentionWrapper(
                attention_cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=None,
                output_attention=False,
                name="gnmt_attention")
            attentive_decoder_cell = GNMTAttentionMultiCell(
                attention_cell,
                self._add_residual_wrapper(self._decoder_cells),
                use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
        else:
            # attentive_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            attentive_decoder_cell = AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanism,
            )
        if self._mode == "train":
            input_vectors = tf.cast(tf.nn.embedding_lookup(
                self._dec_emb_w, tgt_inputs),
                                    dtype=self.params['dtype'])
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=input_vectors, sequence_length=tgt_lengths)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=attentive_decoder_cell,
                helper=helper,
                output_layer=self._output_projection_layer,
                initial_state=attentive_decoder_cell.zero_state(
                    self._batch_size,
                    dtype=encoder_outputs.dtype,
                ),
            )
        elif self._mode == "infer" or self._mode == "eval":
            embedding_fn = lambda ids: tf.cast(tf.nn.embedding_lookup(
                self._dec_emb_w, ids),
                                               dtype=self.params['dtype'])
            helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding=embedding_fn,  #self._dec_emb_w,
                start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
                end_token=self.END_SYMBOL)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=attentive_decoder_cell,
                helper=helper,
                initial_state=attentive_decoder_cell.zero_state(
                    batch_size=self._batch_size,
                    dtype=encoder_outputs.dtype,
                ),
                output_layer=self._output_projection_layer,
            )
        else:
            raise ValueError("Unknown mode for decoder: {}".format(self._mode))

        time_major = self.params.get("time_major", False)
        use_swap_memory = self.params.get("use_swap_memory", False)
        if self._mode == 'train':
            maximum_iterations = tf.reduce_max(tgt_lengths)
        else:
            maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

        final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            # impute_finished=False if self._decoder_type == "beam_search" else True,
            impute_finished=True,
            maximum_iterations=maximum_iterations,
            swap_memory=use_swap_memory,
            output_time_major=time_major,
        )

        return {
            'logits': final_outputs.rnn_output,
            'samples': [tf.argmax(final_outputs.rnn_output, axis=-1)],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths
        }
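
Example #3 additionally consumes 'target_tensors' during training (indices 0 and 1, as read at the top of _decode). A small hedged sketch of that extra input, with assumed shapes, follows.

# Assumed layout of the training-mode inputs for the _decode in example #3.
import tensorflow as tf

batch_size, max_time, enc_dim, tgt_time = 8, 50, 512, 30
input_dict_train = {
    "encoder_output": {
        "outputs": tf.zeros([batch_size, max_time, enc_dim]),
        "src_lengths": tf.fill([batch_size], max_time),
    },
    "target_tensors": [
        tf.zeros([batch_size, tgt_time], dtype=tf.int32),  # tgt_inputs: token ids
        tf.fill([batch_size], tgt_time),                   # tgt_lengths
    ],
}
# In train mode the ids are embedded with DecoderEmbeddingMatrix and fed through
# TrainingHelper; in eval/infer mode GreedyEmbeddingHelper starts from GO_SYMBOL.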
Code example #4
    def _decode(self, input_dict):
        """Decodes representation into data.

    Args:
      input_dict (dict): Python dictionary with inputs to decoder.


    Config parameters:

    * **src_inputs** --- Decoder input Tensor of shape [batch_size, time, dim]
      or [time, batch_size, dim].
    * **src_lengths** --- Decoder input lengths Tensor of shape [batch_size]
    * **tgt_inputs** --- Only during training. labels Tensor of the
      shape [batch_size, time] or [time, batch_size].
    * **tgt_lengths** --- Only during training. labels lengths
      Tensor of the shape [batch_size].

    Returns:
      dict: Python dictionary with:
      * outputs - [predictions, alignments, enc_src_lengths].
        predictions are the final predictions of the model, a tensor of shape
        [batch_size, time]. alignments are the attention probabilities if
        attention is used, or None if 'plot_attention' in attention_params is
        set to False. enc_src_lengths are the lengths of the input, a tensor
        of shape [batch_size].
      * logits - logits with shape [batch_size, output_dim].
      * tgt_length - tensor of shape [batch_size] indicating the predicted
        sequence lengths.
    """
        encoder_outputs = input_dict['encoder_output']['outputs']
        enc_src_lengths = input_dict['encoder_output']['src_length']

        self._batch_size = int(encoder_outputs.get_shape()[0])
        self._beam_width = self.params.get("beam_width", 1)

        tgt_inputs = None
        tgt_lengths = None
        if 'target_tensors' in input_dict:
            tgt_inputs = input_dict['target_tensors'][0]
            tgt_lengths = input_dict['target_tensors'][1]
            tgt_inputs = tf.concat([
                tf.fill([self._batch_size, 1], self.GO_SYMBOL),
                tgt_inputs[:, :-1]
            ], -1)

        layer_type = self.params['rnn_type']
        num_layers = self.params['num_layers']
        attention_params = self.params['attention_params']
        hidden_dim = self.params['hidden_dim']
        dropout_keep_prob = self.params.get(
            'dropout_keep_prob', 1.0) if self._mode == "train" else 1.0

        # TODO: Separate encoder and decoder position embeddings
        use_positional_embedding = self.params.get("pos_embedding", False)
        use_language_model = self.params.get("use_language_model", False)
        use_beam_search_decoder = (self._beam_width != 1) and (self._mode
                                                               == "infer")

        self._target_emb_layer = tf.get_variable(
            name='TargetEmbeddingMatrix',
            shape=[self._tgt_vocab_size, self._tgt_emb_size],
            dtype=tf.float32,
        )

        if use_positional_embedding:
            self.enc_pos_emb_size = int(encoder_outputs.get_shape()[-1])
            self.enc_pos_emb_layer = tf.get_variable(
                name='EncoderPositionEmbeddingMatrix',
                shape=[1024, self.enc_pos_emb_size],
                dtype=tf.float32,
            )
            encoder_output_positions = tf.range(0,
                                                tf.shape(encoder_outputs)[1],
                                                delta=1,
                                                dtype=tf.int32,
                                                name='positional_inputs')
            encoder_position_embeddings = tf.cast(tf.nn.embedding_lookup(
                self.enc_pos_emb_layer, encoder_output_positions),
                                                  dtype=encoder_outputs.dtype)
            encoder_outputs += encoder_position_embeddings

            self.dec_pos_emb_size = self._tgt_emb_size
            self.dec_pos_emb_layer = tf.get_variable(
                name='DecoderPositionEmbeddingMatrix',
                shape=[1024, self.dec_pos_emb_size],
                dtype=tf.float32,
            )

        output_projection_layer = FullyConnected(
            [self._tgt_vocab_size],
            dropout_keep_prob=dropout_keep_prob,
            mode=self._mode,
        )

        rnn_cell = cells_dict[layer_type]

        dropout = tf.nn.rnn_cell.DropoutWrapper

        multirnn_cell = tf.nn.rnn_cell.MultiRNNCell([
            dropout(rnn_cell(hidden_dim), output_keep_prob=dropout_keep_prob)
            for _ in range(num_layers)
        ])

        if use_beam_search_decoder:
            encoder_outputs = tf.contrib.seq2seq.tile_batch(
                encoder_outputs,
                multiplier=self._beam_width,
            )
            enc_src_lengths = tf.contrib.seq2seq.tile_batch(
                enc_src_lengths,
                multiplier=self._beam_width,
            )

        attention_dim = attention_params["attention_dim"]
        attention_type = attention_params["attention_type"]
        num_heads = attention_params["num_heads"]
        plot_attention = attention_params["plot_attention"]
        if plot_attention:
            if use_beam_search_decoder:
                plot_attention = False
                print(
                    "Plotting Attention is disabled for Beam Search Decoding")
            if num_heads != 1:
                plot_attention = False
                print(
                    "Plotting Attention is disabled for Multi Head Attention")
            if self.params['dtype'] != tf.float32:
                plot_attention = False
                print(
                    "Plotting Attention is disabled for Mixed Precision Mode")

        attention_params_dict = {}
        if attention_type == "bahadanu":
            AttentionMechanism = BahdanauAttention
            attention_params_dict["normalize"] = False
        elif attention_type == "chorowski":
            AttentionMechanism = LocationSensitiveAttention
            attention_params_dict["use_coverage"] = attention_params[
                "use_coverage"]
            attention_params_dict["location_attn_type"] = attention_type
            attention_params_dict["location_attention_params"] = {
                'filters': 10,
                'kernel_size': 101
            }
        elif attention_type == "zhaopeng":
            AttentionMechanism = LocationSensitiveAttention
            attention_params_dict["use_coverage"] = attention_params[
                "use_coverage"]
            attention_params_dict["query_dim"] = hidden_dim
            attention_params_dict["location_attn_type"] = attention_type

        attention_mechanism = []

        for head in range(num_heads):
            attention_mechanism.append(
                AttentionMechanism(num_units=attention_dim,
                                   memory=encoder_outputs,
                                   memory_sequence_length=enc_src_lengths,
                                   probability_fn=tf.nn.softmax,
                                   dtype=tf.get_variable_scope().dtype,
                                   **attention_params_dict))

        multirnn_cell_with_attention = AttentionWrapper(
            cell=multirnn_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=[hidden_dim for i in range(num_heads)],
            output_attention=True,
            alignment_history=plot_attention,
        )

        if self._mode == "train":
            decoder_output_positions = tf.range(0,
                                                tf.shape(tgt_inputs)[1],
                                                delta=1,
                                                dtype=tf.int32,
                                                name='positional_inputs')
            tgt_input_vectors = tf.nn.embedding_lookup(self._target_emb_layer,
                                                       tgt_inputs)
            if use_positional_embedding:
                tgt_input_vectors += tf.nn.embedding_lookup(
                    self.dec_pos_emb_layer, decoder_output_positions)
            tgt_input_vectors = tf.cast(
                tgt_input_vectors,
                dtype=self.params['dtype'],
            )
            # helper = tf.contrib.seq2seq.TrainingHelper(
            helper = TrainingHelper(
                inputs=tgt_input_vectors,
                sequence_length=tgt_lengths,
            )
        elif self._mode == "infer" or self._mode == "eval":
            embedding_fn = lambda ids: tf.cast(
                tf.nn.embedding_lookup(self._target_emb_layer, ids),
                dtype=self.params['dtype'],
            )
            pos_embedding_fn = None
            if use_positional_embedding:
                pos_embedding_fn = lambda ids: tf.cast(
                    tf.nn.embedding_lookup(self.dec_pos_emb_layer, ids),
                    dtype=self.params['dtype'],
                )

            # helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            helper = GreedyEmbeddingHelper(
                embedding=embedding_fn,
                start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
                end_token=self.END_SYMBOL,
                positional_embedding=pos_embedding_fn)

        if self._mode != "infer":
            maximum_iterations = tf.reduce_max(tgt_lengths)
        else:
            maximum_iterations = tf.reduce_max(enc_src_lengths)

        if not use_beam_search_decoder:
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=multirnn_cell_with_attention,
                helper=helper,
                initial_state=multirnn_cell_with_attention.zero_state(
                    batch_size=self._batch_size,
                    dtype=encoder_outputs.dtype,
                ),
                output_layer=output_projection_layer,
            )
        else:
            batch_size_tensor = tf.constant(self._batch_size)
            decoder = BeamSearchDecoder(
                cell=multirnn_cell_with_attention,
                embedding=embedding_fn,
                start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
                end_token=self.END_SYMBOL,
                initial_state=multirnn_cell_with_attention.zero_state(
                    dtype=encoder_outputs.dtype,
                    batch_size=batch_size_tensor * self._beam_width,
                ),
                beam_width=self._beam_width,
                output_layer=output_projection_layer,
                length_penalty_weight=0.0,
            )

        final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            impute_finished=self.mode != "infer",
            maximum_iterations=maximum_iterations,
        )

        if plot_attention:
            alignments = tf.transpose(final_state.alignment_history[0].stack(),
                                      [1, 0, 2])
        else:
            alignments = None

        if not use_beam_search_decoder:
            outputs = tf.argmax(final_outputs.rnn_output, axis=-1)
            logits = final_outputs.rnn_output
            return_outputs = [outputs, alignments, enc_src_lengths]
        else:
            outputs = final_outputs.predicted_ids[:, :, 0]
            logits = final_outputs.predicted_ids[:, :, 0]
            return_outputs = [outputs, enc_src_lengths]

        if self.mode == "eval":
            max_len = tf.reduce_max(tgt_lengths)
            logits = tf.while_loop(
                lambda logits: max_len > tf.shape(logits)[1],
                lambda logits: tf.concat([
                    logits,
                    tf.fill([tf.shape(logits)[0], 1,
                             tf.shape(logits)[2]],
                            tf.cast(1.0, self.params['dtype']))
                ], 1),
                loop_vars=[logits],
                back_prop=False,
            )

        return {
            'outputs': return_outputs,
            'logits': logits,
            'tgt_length': final_sequence_lengths,
        }
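
The decoder in example #4 keeps its attention configuration in a nested attention_params dictionary. For orientation, a sketch covering the keys read above is shown below; the values are assumptions, not a recommended setup.

# Illustrative attention_params for the _decode in example #4; keys match the
# attention_params[...] reads in the code, values are assumptions.
attention_params = {
    "attention_dim": 256,
    "attention_type": "chorowski",   # one of "bahadanu", "chorowski", "zhaopeng"
    "num_heads": 1,                  # more than one head disables attention plotting
    "plot_attention": True,          # also disabled for beam search and non-fp32 dtypes
    "use_coverage": True,            # read by the location-sensitive variants
}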
Code example #5
File: rnn_decoders.py  Project: zhengjxu/OpenSeq2Seq
  def _decode(self, input_dict):
    """Decodes representation into data.

    Args:
      input_dict (dict): Python dictionary with inputs to decoder

    Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                     or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
    Does not need tgt_inputs and tgt_lengths

    Returns:
      dict: a Python dictionary with:
      * final_outputs - tensor of shape [batch_size, time, dim] or
                        [time, batch_size, dim]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size, time] or
                                 [time, batch_size]
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']

    self._dec_emb_w = tf.get_variable(
        name='DecoderEmbeddingMatrix',
        shape=[self._tgt_vocab_size, self._tgt_emb_size],
        dtype=tf.float32
    )

    self._output_projection_layer = tf.layers.Dense(
        self._tgt_vocab_size, use_bias=False,
    )

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    residual_connections = self.params['decoder_use_skip_connections']
    # list of cells
    self._decoder_cells = [
        single_cell(
            cell_class=self.params['core_cell'],
            cell_params=self.params.get('core_cell_params', {}),
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            # residual connections are added a little differently for GNMT
            residual_connections=False if self.params['attention_type'].startswith('gnmt')
                                 else residual_connections,
        ) for _ in range(self.params['decoder_layers'])
    ]

    # pylint: disable=no-member
    tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs,
        multiplier=self._beam_width,
    )
    # pylint: disable=no-member
    tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
        enc_src_lengths,
        multiplier=self._beam_width,
    )
    attention_mechanism = self._build_attention(
        tiled_enc_outputs,
        tiled_enc_src_lengths,
    )

    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      attention_cell = AttentionWrapper(
          attention_cell,
          attention_mechanism=attention_mechanism,
          attention_layer_size=None,  # don't use attention layer.
          output_attention=False,
          name="gnmt_attention",
      )
      attentive_decoder_cell = GNMTAttentionMultiCell(
          attention_cell,
          self._add_residual_wrapper(self._decoder_cells) if residual_connections else self._decoder_cells,
          use_new_attention=(self.params['attention_type'] == 'gnmt_v2')
      )
    else:  # non-GNMT
      attentive_decoder_cell = AttentionWrapper(
          # pylint: disable=no-member
          cell=tf.contrib.rnn.MultiRNNCell(self._decoder_cells),
          attention_mechanism=attention_mechanism,
      )
    batch_size_tensor = tf.constant(self._batch_size)
    embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, ids),
        dtype=self.params['dtype'],
    )
    decoder = BeamSearchDecoder(
        cell=attentive_decoder_cell,
        embedding=embedding_fn,
        start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
        end_token=self.END_SYMBOL,
        initial_state=attentive_decoder_cell.zero_state(
            dtype=encoder_outputs.dtype,
            batch_size=batch_size_tensor * self._beam_width,
        ),
        beam_width=self._beam_width,
        output_layer=self._output_projection_layer,
        length_penalty_weight=self._length_penalty_weight
    )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(  # pylint: disable=no-member
            decoder=decoder,
            maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
            swap_memory=use_swap_memory,
            output_time_major=time_major,
        )

    # predicted_ids[:, :, 0] is 2-D ([time, batch] when time_major), hence perm=[1, 0]
    return {'logits': final_outputs.predicted_ids[:, :, 0] if not time_major else
            tf.transpose(final_outputs.predicted_ids[:, :, 0], perm=[1, 0]),
            'outputs': [final_outputs.predicted_ids[:, :, 0]],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths}
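
Unlike examples #2 and #3, this revision builds its layers through single_cell using 'core_cell' and 'core_cell_params'. A hedged sketch of the decoder-side params it reads (values are assumptions) follows.

# Assumed-value sketch of the decoder params read by the _decode in example #5.
import tensorflow as tf

decoder_params = {
    "core_cell": tf.nn.rnn_cell.LSTMCell,
    "core_cell_params": {"num_units": 512},
    "decoder_layers": 2,
    "decoder_dp_input_keep_prob": 0.8,
    "decoder_dp_output_keep_prob": 0.8,
    "decoder_use_skip_connections": False,
    "attention_type": "gnmt_v2",     # any value starting with 'gnmt' takes the GNMT branch
    "time_major": False,
    "use_swap_memory": False,
    "dtype": tf.float32,
}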
Code example #6
File: rnn_decoders.py  Project: zhengjxu/OpenSeq2Seq
  def _decode(self, input_dict):
    """Decodes representation into data.

    Args:
      input_dict (dict): Python dictionary with inputs to decoder.


    Config parameters:

    * **src_inputs** --- Decoder input Tensor of shape [batch_size, time, dim]
      or [time, batch_size, dim]
    * **src_lengths** --- Decoder input lengths Tensor of shape [batch_size]
    * **tgt_inputs** --- Only during training. labels Tensor of the
      shape [batch_size, time] or [time, batch_size].
    * **tgt_lengths** --- Only during training. labels lengths
      Tensor of the shape [batch_size].

    Returns:
      dict: Python dictionary with:
      * final_outputs - tensor of shape [batch_size, time, dim]
                        or [time, batch_size, dim]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size, time]
                                 or [time, batch_size]
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']
    tgt_inputs = input_dict['target_tensors'][0] if 'target_tensors' in \
                                                    input_dict else None
    tgt_lengths = input_dict['target_tensors'][1] if 'target_tensors' in \
                                                     input_dict else None

    self._dec_emb_w = tf.get_variable(
        name='DecoderEmbeddingMatrix',
        shape=[self._tgt_vocab_size, self._tgt_emb_size],
        dtype=tf.float32,
    )

    self._output_projection_layer = tf.layers.Dense(
        self._tgt_vocab_size, use_bias=False,
    )

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    residual_connections = self.params['decoder_use_skip_connections']

    # list of cells
    self._decoder_cells = [
        single_cell(
            cell_class=self.params['core_cell'],
            cell_params=self.params.get('core_cell_params', {}),
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            # residual connections are added a little differently for GNMT
            residual_connections=False if self.params['attention_type'].startswith('gnmt')
                                 else residual_connections,
        ) for _ in range(self.params['decoder_layers'])
    ]

    attention_mechanism = self._build_attention(
        encoder_outputs,
        enc_src_lengths,
    )
    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      attention_cell = AttentionWrapper(
          attention_cell,
          attention_mechanism=attention_mechanism,
          attention_layer_size=None,
          output_attention=False,
          name="gnmt_attention",
      )
      attentive_decoder_cell = GNMTAttentionMultiCell(
          attention_cell,
          self._add_residual_wrapper(self._decoder_cells) if residual_connections else self._decoder_cells,
          use_new_attention=(self.params['attention_type'] == 'gnmt_v2'),
      )
    else:
      attentive_decoder_cell = AttentionWrapper(
          # pylint: disable=no-member
          cell=tf.contrib.rnn.MultiRNNCell(self._decoder_cells),
          attention_mechanism=attention_mechanism,
      )
    if self._mode == "train":
      input_vectors = tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
        dtype=self.params['dtype'],
      )
      helper = tf.contrib.seq2seq.TrainingHelper(  # pylint: disable=no-member
          inputs=input_vectors,
          sequence_length=tgt_lengths,
      )
      decoder = tf.contrib.seq2seq.BasicDecoder(  # pylint: disable=no-member
          cell=attentive_decoder_cell,
          helper=helper,
          output_layer=self._output_projection_layer,
          initial_state=attentive_decoder_cell.zero_state(
              self._batch_size, dtype=encoder_outputs.dtype,
          ),
      )
    elif self._mode == "infer" or self._mode == "eval":
      embedding_fn = lambda ids: tf.cast(
          tf.nn.embedding_lookup(self._dec_emb_w, ids),
          dtype=self.params['dtype'],
      )
      # pylint: disable=no-member
      helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
          embedding=embedding_fn,
          start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
          end_token=self.END_SYMBOL,
      )
      decoder = tf.contrib.seq2seq.BasicDecoder(  # pylint: disable=no-member
          cell=attentive_decoder_cell,
          helper=helper,
          initial_state=attentive_decoder_cell.zero_state(
              batch_size=self._batch_size, dtype=encoder_outputs.dtype,
          ),
          output_layer=self._output_projection_layer,
      )
    else:
      raise ValueError(
          "Unknown mode for decoder: {}".format(self._mode)
      )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    if self._mode == 'train':
      maximum_iterations = tf.reduce_max(tgt_lengths)
    else:
      maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

    # pylint: disable=no-member
    final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        impute_finished=True,
        maximum_iterations=maximum_iterations,
        swap_memory=use_swap_memory,
        output_time_major=time_major,
    )

    return {'logits': final_outputs.rnn_output if not time_major else
            tf.transpose(final_outputs.rnn_output, perm=[1, 0, 2]),
            'outputs': [tf.argmax(final_outputs.rnn_output, axis=-1)],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths}
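
Finally, as a hedged illustration of how the dictionary returned by example #6 could be consumed, the fragment below masks the logits by the decoded lengths and computes a token-level cross-entropy. This loss is an assumption for illustration only, not the project's loss implementation, and the placeholder tensors stand in for the real decoder outputs.

# Hypothetical consumer of the _decode output in example #6 (not OpenSeq2Seq code).
import tensorflow as tf

batch_size, tgt_time, vocab = 8, 30, 32000
logits = tf.zeros([batch_size, tgt_time, vocab])     # stand-in for result['logits']
lengths = tf.fill([batch_size], tgt_time)            # stand-in for result['final_sequence_lengths']
labels = tf.zeros([batch_size, tgt_time], tf.int32)  # reference token ids

mask = tf.sequence_mask(lengths, maxlen=tgt_time, dtype=logits.dtype)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits,
                                               weights=mask)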