Example #1
    def _build_encoder(self):
        with tf.variable_scope('encoder'):
            # Pyramidal BiLSTM: each layer halves the time resolution of its input.
            if self.hparams.encoder_type == 'pbilstm':
                cells_fw = [model_utils.single_cell("lstm", self.hparams.encoder_num_units // 2, self.mode) for _ in
                            range(self.hparams.num_encoder_layers)]
                cells_bw = [model_utils.single_cell("lstm", self.hparams.encoder_num_units // 2, self.mode) for _ in
                            range(self.hparams.num_encoder_layers)]

                prev_layer = self.inputs
                prev_seq_len = self.input_seq_len

                with tf.variable_scope("stack_p_bidirectional_rnn"):
                    state_fw = state_bw = None
                    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
                        initial_state_fw = None
                        initial_state_bw = None

                        # Pyramidal step: drop an odd trailing frame, fold every
                        # two consecutive time steps into one (doubling the feature
                        # dimension), and halve the sequence lengths to match.
                        size = tf.shape(prev_layer)[1] // 2
                        prev_layer = prev_layer[:, :size * 2, :]
                        prev_layer = tf.reshape(prev_layer,
                                                [tf.shape(prev_layer)[0], size, prev_layer.shape[2] * 2])
                        prev_seq_len = prev_seq_len // 2

                        with tf.variable_scope("cell_%d" % i):
                            outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                                cell_fw,
                                cell_bw,
                                prev_layer,
                                initial_state_fw=initial_state_fw,
                                initial_state_bw=initial_state_bw,
                                sequence_length=prev_seq_len,
                                dtype=tf.float32
                            )
                            # Concat the outputs to create the new input.
                            prev_layer = tf.concat(outputs, axis=2)
                # Return the top layer's outputs and the concatenated final
                # forward/backward states.
                return prev_layer, tf.concat([state_fw, state_bw], -1)
            elif self.hparams.encoder_type == 'bilstm':
                cells_fw = [model_utils.single_cell("lstm", self.hparams.encoder_num_units // 2, self.mode) for _ in
                            range(self.hparams.num_encoder_layers)]
                cells_bw = [model_utils.single_cell("lstm", self.hparams.encoder_num_units // 2, self.mode) for _ in
                            range(self.hparams.num_encoder_layers)]
                outputs, output_states_fw, output_states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw, cells_bw,
                    self.inputs, sequence_length=self.input_seq_len,
                    dtype=tf.float32)
                return outputs, tf.concat([output_states_fw, output_states_bw], -1)
            elif self.hparams.encoder_type == 'lstm':
                cells = [model_utils.single_cell("lstm", self.hparams.encoder_num_units, self.mode) for _ in
                         range(self.hparams.num_encoder_layers)]
                cell = tf.nn.rnn_cell.MultiRNNCell(cells)
                outputs, state = tf.nn.dynamic_rnn(cell, self.inputs, sequence_length=self.input_seq_len,
                                                   dtype=tf.float32)
                return outputs, state
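
All four examples rely on a model_utils.single_cell helper that is not shown on this page. Below is a minimal sketch of what such a cell factory could look like, assuming the TF 1.x tf.nn.rnn_cell API and estimator-style mode constants; the real helper's signature and dropout handling may differ.

import tensorflow as tf

def single_cell(cell_type, num_units, mode, dropout=0.0):
    # Build one RNN cell of the requested type.
    if cell_type == "lstm":
        cell = tf.nn.rnn_cell.LSTMCell(num_units)
    else:
        cell = tf.nn.rnn_cell.GRUCell(num_units)
    # Hypothetical convention: apply dropout only while training,
    # never at inference time.
    if dropout > 0.0 and mode == tf.estimator.ModeKeys.TRAIN:
        cell = tf.nn.rnn_cell.DropoutWrapper(
            cell, output_keep_prob=1.0 - dropout)
    return cell
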
Example #2
    def _get_decoder_cell(self):
        # A single layer needs no MultiRNNCell wrapper.
        if self.hparams.num_decoder_layers == 1:
            decoder_cell = model_utils.single_cell(
                "lstm", self.hparams.decoder_num_units, self.mode)
        else:
            cells = [model_utils.single_cell("lstm", self.hparams.decoder_num_units, self.mode) for _ in
                     range(self.hparams.num_decoder_layers)]
            decoder_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

        return decoder_cell
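
For context, a decoder cell built this way is typically driven through the tf.contrib.seq2seq decoding loop. A rough standalone sketch with invented shapes and sizes (32-dim embeddings, 2 layers of 64 units), not the surrounding project's actual decoding code:

import tensorflow as tf

# Hypothetical inputs: [batch, time, embedding] and per-example lengths.
decoder_inputs = tf.placeholder(tf.float32, [None, None, 32])
decoder_lengths = tf.placeholder(tf.int32, [None])

cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, decoder_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell, helper,
    initial_state=cell.zero_state(tf.shape(decoder_inputs)[0], tf.float32))
# outputs.rnn_output holds the per-step decoder activations.
outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
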
Example #3
    def _build_da_word_encoder(self, inputs, input_seq_len):
        with tf.variable_scope('encoder') as encoder_scope:
            if self.hparams.da_word_encoder_type == 'bilstm':
                cells_fw = [
                    model_utils.single_cell(
                        "lstm",
                        self.hparams.da_word_encoder_num_units // 2,
                        self.mode,
                        dropout=self.hparams.dropout)
                    for _ in range(self.hparams.num_da_word_encoder_layers)
                ]
                cells_bw = [
                    model_utils.single_cell(
                        "lstm",
                        self.hparams.da_word_encoder_num_units // 2,
                        self.mode,
                        dropout=self.hparams.dropout)
                    for _ in range(self.hparams.num_da_word_encoder_layers)
                ]
                outputs, output_states_fw, output_states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw,
                    cells_bw,
                    inputs,
                    sequence_length=input_seq_len,
                    dtype=tf.float32)
                return outputs, tf.concat([output_states_fw, output_states_bw], -1)
            elif self.hparams.da_word_encoder_type == 'lstm':
                cells = [
                    model_utils.single_cell(
                        "lstm", self.hparams.da_word_encoder_num_units,
                        self.mode)
                    for _ in range(self.hparams.num_da_word_encoder_layers)
                ]
                cell = tf.nn.rnn_cell.MultiRNNCell(cells)
                outputs, state = tf.nn.dynamic_rnn(
                    cell,
                    inputs,
                    sequence_length=input_seq_len,
                    dtype=tf.float32)
                return outputs, state
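
To make the shapes of the 'bilstm' branch concrete, here is a self-contained toy run of the same stack_bidirectional_dynamic_rnn pattern, with invented dimensions (2 layers, 16 units per direction, 10-dim inputs):

import numpy as np
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, None, 10])
seq_len = tf.placeholder(tf.int32, [None])

cells_fw = [tf.nn.rnn_cell.LSTMCell(16) for _ in range(2)]
cells_bw = [tf.nn.rnn_cell.LSTMCell(16) for _ in range(2)]
outputs, states_fw, states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
    cells_fw, cells_bw, inputs, sequence_length=seq_len, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs, {
        inputs: np.zeros([4, 7, 10], np.float32),
        seq_len: np.full([4], 7, np.int32),
    })
    print(out.shape)  # (4, 7, 32): fw and bw outputs concatenated per step
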
Example #4
    def _build_utt_encoder(self, history_inputs, history_input_seq_len):
        # Encoder at the utterance level: a single LSTM over the sequence
        # of utterance representations.
        history_encoder_outputs, history_encoder_state = tf.nn.dynamic_rnn(
            model_utils.single_cell("lstm",
                                    self.hparams.utt_encoder_num_units,
                                    self.mode,
                                    dropout=self.hparams.dropout),
            history_inputs,
            sequence_length=history_input_seq_len,
            dtype=tf.float32)
        # history_encoder_state is an LSTMStateTuple; [0] selects the cell state.
        return history_encoder_state[0]
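
The final line indexes an LSTMStateTuple, so [0] selects its first field, the cell state c (index 1, equivalently state.h, would be the hidden output). A quick standalone check of that field order:

import tensorflow as tf

cell = tf.nn.rnn_cell.LSTMCell(8)
inputs = tf.zeros([2, 5, 3])
_, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
print(type(state).__name__)  # LSTMStateTuple
print(state[0] is state.c)   # True: index 0 is the cell state
print(state[1] is state.h)   # True: index 1 is the hidden state
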