def __call__(self, inputs, seq_len, return_last_state=False, time_major=False): assert not time_major, "BiRNN class cannot support time_major currently" with tf.variable_scope(self.scope): flat_inputs = flatten(inputs, keep=2) # reshape to [-1, max_time, dim] seq_len = flatten( seq_len, keep=0) # reshape to [x] (one dimension sequence) outputs, ((_, h_fw), (_, h_bw)) = bidirectional_dynamic_rnn( self.cell_fw, self.cell_bw, flat_inputs, sequence_length=seq_len, dtype=tf.float32) if return_last_state: # return last states output = tf.concat([h_fw, h_bw], axis=-1) # shape = [-1, 2 * num_units] output = reconstruct( output, ref=inputs, keep=2, remove_shape=1) # remove the max_time shape else: output = tf.concat( outputs, axis=-1) # shape = [-1, max_time, 2 * num_units] output = reconstruct( output, ref=inputs, keep=2 ) # reshape to same as inputs, except the last two dim return output
def __call__(self, inputs, seq_len, return_last_state=False, time_major=False): assert not time_major, "StackBiRNN class cannot support time_major currently" with tf.variable_scope(self.scope): flat_inputs = flatten(inputs, keep=2) # reshape to [-1, max_time, dim] seq_len = flatten( seq_len, keep=0) # reshape to [x] (one dimension sequence) outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn( self.cells_fw, self.cells_fw, flat_inputs, sequence_length=seq_len, dtype=tf.float32) if return_last_state: # return last states # since states_fw is the final states, one tensor per layer, of the forward rnn and states_bw is the # final states, one tensor per layer, of the backward rnn, here we extract the last layer of forward # and backward states as last state h_fw, h_bw = states_fw[self.num_layers - 1].h, states_bw[self.num_layers - 1].h output = tf.concat([h_fw, h_bw], axis=-1) # shape = [-1, 2 * num_units] output = reconstruct( output, ref=inputs, keep=2, remove_shape=1) # remove the max_time shape else: output = tf.concat( outputs, axis=-1) # shape = [-1, max_time, 2 * num_units] output = reconstruct( output, ref=inputs, keep=2 ) # reshape to same as inputs, except the last two dim return output
def __call__(self, inputs, seq_len, time_major=False): assert not time_major, "DenseConnectBiRNN class cannot support time_major currently" # this function does not support return_last_state method currently with tf.variable_scope(self.scope): flat_inputs = flatten(inputs, keep=2) # reshape to [-1, max_time, dim] seq_len = flatten( seq_len, keep=0) # reshape to [x] (one dimension sequence) cur_inputs = flat_inputs for i in range(self.num_layers): cur_outputs = self.dense_bi_rnn[i](cur_inputs, seq_len) if i < self.num_layers - 1: cur_inputs = tf.concat([cur_inputs, cur_outputs], axis=-1) else: cur_inputs = cur_outputs output = reconstruct(cur_inputs, ref=inputs, keep=2) return output