def contextual_bi_rnn(tensor_rep, mask_rep, hn, cell_type, only_final=False,
                      wd=0., keep_prob=1., is_train=None, scope=None):
    """
    fuse contextual information using a bi-directional rnn
    :param tensor_rep: [..., sl, vec]
    :param mask_rep: [..., sl]
    :param hn: hidden units per direction
    :param cell_type: 'gru', 'lstm', 'basic_lstm' or 'basic_rnn'
    :param only_final: True or False
    :param wd: weight decay factor (L2 regularization is added if > 0)
    :param keep_prob: dropout keep probability
    :param is_train: bool tensor switching dropout on/off
    :param scope: variable scope name
    :return: [..., sl, 2 * hn] if only_final is False, else [..., 2 * hn]
    """
    with tf.variable_scope(scope or 'contextual_bi_rnn'):
        reuse = None if not tf.get_variable_scope().reuse else True
        if cell_type == 'gru':
            cell_fw = tf.contrib.rnn.GRUCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.GRUCell(hn, reuse=reuse)
        elif cell_type == 'lstm':
            cell_fw = tf.contrib.rnn.LSTMCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.LSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_lstm':
            cell_fw = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_rnn':
            cell_fw = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse)
        else:
            raise AttributeError('no cell type \'%s\'' % cell_type)
        cell_dp_fw = SwitchableDropoutWrapper(cell_fw, is_train, keep_prob)
        cell_dp_bw = SwitchableDropoutWrapper(cell_bw, is_train, keep_prob)

        tensor_len = tf.reduce_sum(tf.cast(mask_rep, tf.int32), -1)  # [bs]
        (outputs_fw, outputs_bw), _ = bidirectional_dynamic_rnn(
            cell_dp_fw, cell_dp_bw, tensor_rep, tensor_len, dtype=tf.float32)
        rnn_outputs = tf.concat([outputs_fw, outputs_bw], -1)  # [..., sl, 2hn]

        if wd > 0:
            add_reg_without_bias()
        if not only_final:
            return rnn_outputs  # [..., sl, 2hn]
        else:
            return get_last_state(rnn_outputs, mask_rep)  # [..., 2hn]
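# A minimal usage sketch for contextual_bi_rnn (an assumption-laden illustration, not part of
# the original module): it presumes a TensorFlow 1.x graph and that the helpers used above
# (SwitchableDropoutWrapper, bidirectional_dynamic_rnn, get_last_state) are importable from
# this module. The placeholder shapes, hidden size and scope names below are illustrative only.
#
#     import tensorflow as tf
#
#     sent_rep = tf.placeholder(tf.float32, [None, 40, 300])  # [bs, sl, vec]
#     sent_mask = tf.placeholder(tf.bool, [None, 40])         # [bs, sl]
#     is_train = tf.placeholder(tf.bool, [])                  # dropout switch
#
#     # Per-token contextual encoding: [bs, sl, 2 * 100]
#     ctx_rep = contextual_bi_rnn(sent_rep, sent_mask, hn=100, cell_type='gru',
#                                 keep_prob=0.8, is_train=is_train, scope='ctx_fusion')
#
#     # Sequence-level encoding (last valid state of each sequence): [bs, 2 * 100]
#     sent_vec = contextual_bi_rnn(sent_rep, sent_mask, hn=100, cell_type='gru',
#                                  only_final=True, keep_prob=0.8, is_train=is_train,
#                                  scope='sent_encoder')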
    def do_reduce(self, data_for_reduce, mask_for_reduce):
        with tf.variable_scope('sr_%s' % self.method_type):
            seq_len = tf.reduce_sum(tf.cast(mask_for_reduce, tf.int32), -1)  # [bs]
            (fw, bw), _ = tf.nn.bidirectional_dynamic_rnn(
                self.cell_dp_fw, self.cell_dp_bw, data_for_reduce, seq_len,
                dtype=tf.float32, scope='shift_reduce_bilstm_loop')
            value = tf.concat([fw, bw], -1)  # [bs, sl, 2hn]
            # take the last valid hidden state of each sequence as the reduced representation
            processed_reduced_data = get_last_state(value, mask_for_reduce)  # [bs, 2hn]
            return processed_reduced_data
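# do_reduce is a method and expects its enclosing class to have already built the
# dropout-wrapped forward/backward cells it references. That class is not shown in this
# section, so the sketch below is only an assumption of what its constructor might look
# like; the class name and the method_type/hn/keep_prob/is_train attributes are hypothetical,
# mirroring the cell construction in contextual_bi_rnn above.
#
#     class ShiftReduceModule(object):
#         """Hypothetical host class for do_reduce; names and attributes are assumptions."""
#
#         def __init__(self, method_type, hn, keep_prob, is_train):
#             self.method_type = method_type  # used to name the 'sr_%s' variable scope
#             self.hn = hn
#             # forward/backward cells wrapped with train-time-only dropout
#             cell_fw = tf.contrib.rnn.LSTMCell(hn)
#             cell_bw = tf.contrib.rnn.LSTMCell(hn)
#             self.cell_dp_fw = SwitchableDropoutWrapper(cell_fw, is_train, keep_prob)
#             self.cell_dp_bw = SwitchableDropoutWrapper(cell_bw, is_train, keep_prob)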
def one_direction_rnn(tensor_rep, mask_rep, hn, cell_type, only_final=False,
                      wd=0., keep_prob=1., is_train=None, is_forward=True, scope=None):
    assert not is_forward  # TODO: waiting to be implemented
    with tf.variable_scope(scope or '%s_rnn' % ('forward' if is_forward else 'backward')):
        reuse = None if not tf.get_variable_scope().reuse else True
        if cell_type == 'gru':
            cell = tf.contrib.rnn.GRUCell(hn, reuse=reuse)
        elif cell_type == 'lstm':
            cell = tf.contrib.rnn.LSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_lstm':
            cell = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_rnn':
            cell = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse)
        else:
            raise AttributeError('no cell type \'%s\'' % cell_type)
        cell_dp = SwitchableDropoutWrapper(cell, is_train, keep_prob)

        tensor_len = tf.reduce_sum(tf.cast(mask_rep, tf.int32), -1)  # [bs]
        rnn_outputs, _ = dynamic_rnn(cell_dp, tensor_rep, tensor_len, dtype=tf.float32)

        if wd > 0:
            add_reg_without_bias()
        if not only_final:
            return rnn_outputs  # [..., sl, hn]
        else:
            return get_last_state(rnn_outputs, mask_rep)  # [..., hn]
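# A short usage sketch for one_direction_rnn (assumptions: same TF 1.x setup and module-level
# helpers as above; placeholder shapes are illustrative). Note that, as written, the function
# asserts that is_forward is False, so the call below passes is_forward=False explicitly.
#
#     tok_rep = tf.placeholder(tf.float32, [None, 40, 300])  # [bs, sl, vec]
#     tok_mask = tf.placeholder(tf.bool, [None, 40])          # [bs, sl]
#     is_train = tf.placeholder(tf.bool, [])
#
#     # Single-direction encoding; output is [bs, sl, 100] (hn, not 2 * hn)
#     uni_rep = one_direction_rnn(tok_rep, tok_mask, hn=100, cell_type='lstm',
#                                 keep_prob=0.8, is_train=is_train,
#                                 is_forward=False, scope='uni_rnn')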