def _build_encoder(self):
    """Build the input encoder selected by ``self.hparams.encoder_type``.

    Supported types:
      * ``'pbilstm'`` -- pyramidal bidirectional LSTM: each layer merges
        consecutive frame pairs (halving time, doubling feature depth)
        before running a bidirectional RNN.
      * ``'bilstm'``  -- stacked bidirectional LSTM.
      * ``'lstm'``    -- stacked unidirectional multi-layer LSTM.

    Returns:
        ``(outputs, state)``: the per-timestep encoder outputs and the
        final state. For the bidirectional variants the forward and
        backward states are concatenated on the last axis.

    Raises:
        ValueError: if ``encoder_type`` is not one of the supported values.
    """
    with tf.variable_scope('encoder'):
        encoder_type = self.hparams.encoder_type
        if encoder_type == 'pbilstm':
            # Half-sized cells per direction so the concatenated
            # fw+bw output width matches encoder_num_units.
            cells_fw = [model_utils.single_cell(
                            "lstm", self.hparams.encoder_num_units // 2, self.mode)
                        for _ in range(self.hparams.num_encoder_layers)]
            cells_bw = [model_utils.single_cell(
                            "lstm", self.hparams.encoder_num_units // 2, self.mode)
                        for _ in range(self.hparams.num_encoder_layers)]
            prev_layer = self.inputs
            prev_seq_len = self.input_seq_len
            with tf.variable_scope("stack_p_bidirectional_rnn"):
                state_fw = state_bw = None
                for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
                    # Drop a trailing odd frame, then fold each pair of
                    # consecutive frames into one: time dimension is
                    # halved, feature depth doubled. Note the static
                    # feature size (prev_layer.shape[2]) must be known.
                    size = tf.shape(prev_layer)[1] // 2
                    prev_layer = prev_layer[:, :size * 2, :]
                    prev_layer = tf.reshape(
                        prev_layer,
                        [tf.shape(prev_layer)[0], size, prev_layer.shape[2] * 2])
                    prev_seq_len = prev_seq_len // 2
                    with tf.variable_scope("cell_%d" % i):
                        outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                            cell_fw,
                            cell_bw,
                            prev_layer,
                            sequence_length=prev_seq_len,
                            dtype=tf.float32)
                    # Concatenated fw/bw outputs feed the next pyramid layer.
                    prev_layer = tf.concat(outputs, axis=2)
            # Only the top layer's final states are returned.
            return prev_layer, tf.concat([state_fw, state_bw], -1)
        elif encoder_type == 'bilstm':
            cells_fw = [model_utils.single_cell(
                            "lstm", self.hparams.encoder_num_units // 2, self.mode)
                        for _ in range(self.hparams.num_encoder_layers)]
            cells_bw = [model_utils.single_cell(
                            "lstm", self.hparams.encoder_num_units // 2, self.mode)
                        for _ in range(self.hparams.num_encoder_layers)]
            outputs, output_states_fw, output_states_bw = \
                tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw, cells_bw, self.inputs,
                    sequence_length=self.input_seq_len,
                    dtype=tf.float32)
            return outputs, tf.concat([output_states_fw, output_states_bw], -1)
        elif encoder_type == 'lstm':
            cells = [model_utils.single_cell(
                         "lstm", self.hparams.encoder_num_units, self.mode)
                     for _ in range(self.hparams.num_encoder_layers)]
            cell = tf.nn.rnn_cell.MultiRNNCell(cells)
            outputs, state = tf.nn.dynamic_rnn(
                cell, self.inputs,
                sequence_length=self.input_seq_len,
                dtype=tf.float32)
            return outputs, state
        # Previously an unknown type fell through and returned None,
        # producing a confusing failure downstream; fail loudly instead.
        raise ValueError("Unknown encoder_type: %s" % encoder_type)
def _get_decoder_cell(self):
    """Create the decoder RNN cell.

    A single layer yields a bare LSTM cell; multiple layers are stacked
    inside a ``MultiRNNCell`` wrapper.
    """
    def make_cell():
        # Factory so each stacked layer gets its own cell instance.
        return model_utils.single_cell(
            "lstm", self.hparams.decoder_num_units, self.mode)

    num_layers = self.hparams.num_decoder_layers
    if num_layers == 1:
        return make_cell()
    return tf.nn.rnn_cell.MultiRNNCell(
        [make_cell() for _ in range(num_layers)])
def _build_da_word_encoder(self, inputs, input_seq_len):
    """Encode dialogue-act word sequences with the configured RNN.

    Args:
        inputs: batched embedded word sequences to encode.
        input_seq_len: per-example sequence lengths.

    Returns:
        ``(outputs, state)``: per-timestep outputs and the final state
        (fw/bw states concatenated on the last axis for ``'bilstm'``).

    Raises:
        ValueError: if ``da_word_encoder_type`` is not recognized.
    """
    with tf.variable_scope('encoder'):
        encoder_type = self.hparams.da_word_encoder_type
        if encoder_type == 'bilstm':
            # Half-sized cells per direction so concatenated width
            # matches da_word_encoder_num_units.
            cells_fw = [
                model_utils.single_cell(
                    "lstm", self.hparams.da_word_encoder_num_units // 2,
                    self.mode, dropout=self.hparams.dropout)
                for _ in range(self.hparams.num_da_word_encoder_layers)
            ]
            cells_bw = [
                model_utils.single_cell(
                    "lstm", self.hparams.da_word_encoder_num_units // 2,
                    self.mode, dropout=self.hparams.dropout)
                for _ in range(self.hparams.num_da_word_encoder_layers)
            ]
            outputs, output_states_fw, output_states_bw = \
                tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw, cells_bw, inputs,
                    sequence_length=input_seq_len,
                    dtype=tf.float32)
            return outputs, tf.concat([output_states_fw, output_states_bw], -1)
        elif encoder_type == 'lstm':
            # NOTE(review): unlike the bilstm branch, no dropout= is
            # passed here -- confirm whether that asymmetry is intended.
            cells = [
                model_utils.single_cell(
                    "lstm", self.hparams.da_word_encoder_num_units, self.mode)
                for _ in range(self.hparams.num_da_word_encoder_layers)
            ]
            cell = tf.nn.rnn_cell.MultiRNNCell(cells)
            outputs, state = tf.nn.dynamic_rnn(
                cell, inputs,
                sequence_length=input_seq_len,
                dtype=tf.float32)
            return outputs, state
        # Previously an unknown type silently returned None; fail loudly.
        raise ValueError("Unknown da_word_encoder_type: %s" % encoder_type)
def _build_utt_encoder(self, history_inputs, history_input_seq_len):
    """Run the utterance-level encoder over the dialogue history.

    Args:
        history_inputs: batched sequence of per-utterance representations.
        history_input_seq_len: number of valid utterances per example.

    Returns:
        The first component of the final RNN state (presumably the ``c``
        part of an LSTMStateTuple -- verify against callers).
    """
    utt_cell = model_utils.single_cell(
        "lstm",
        self.hparams.utt_encoder_num_units,
        self.mode,
        dropout=self.hparams.dropout)
    # Per-step outputs are discarded; only the final state is consumed.
    _, final_state = tf.nn.dynamic_rnn(
        utt_cell,
        history_inputs,
        sequence_length=history_input_seq_len,
        dtype=tf.float32)
    return final_state[0]