Example #1
def contextual_bi_rnn(tensor_rep,
                      mask_rep,
                      hn,
                      cell_type,
                      only_final=False,
                      wd=0.,
                      keep_prob=1.,
                      is_train=None,
                      scope=None):
    """
    fusing contextual information using bi-direction rnn
    :param tensor_rep: [..., sl, vec]
    :param mask_rep: [..., sl]
    :param hn:
    :param cell_type: 'gru', 'lstm', basic_lstm' and 'basic_rnn'
    :param only_final: True or False
    :param wd:
    :param keep_prob:
    :param is_train:
    :param scope:
    :return:
    """
    with tf.variable_scope(scope or 'contextual_bi_rnn'):
        reuse = tf.get_variable_scope().reuse or None
        if cell_type == 'gru':
            cell_fw = tf.contrib.rnn.GRUCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.GRUCell(hn, reuse=reuse)
        elif cell_type == 'lstm':
            cell_fw = tf.contrib.rnn.LSTMCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.LSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_lstm':
            cell_fw = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_rnn':
            cell_fw = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse)
            cell_bw = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse)
        else:
            raise ValueError("unknown cell type '%s'" % cell_type)
        cell_dp_fw = SwitchableDropoutWrapper(cell_fw, is_train, keep_prob)
        cell_dp_bw = SwitchableDropoutWrapper(cell_bw, is_train, keep_prob)

        tensor_len = tf.reduce_sum(tf.cast(mask_rep, tf.int32), -1)  # [bs]

        (outputs_fw,
         outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(cell_dp_fw,
                                                          cell_dp_bw,
                                                          tensor_rep,
                                                          tensor_len,
                                                          dtype=tf.float32)
        rnn_outputs = tf.concat([outputs_fw, outputs_bw], -1)  # [..., sl, 2*hn]

        if wd > 0:
            add_reg_without_bias()
        if not only_final:
            return rnn_outputs  # [..., sl, 2*hn]
        else:
            return get_last_state(rnn_outputs, mask_rep)  # [..., 2*hn]
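
A minimal usage sketch for contextual_bi_rnn (not part of the original listing): the placeholder shapes, hidden size and scope names below are illustrative assumptions, and the project helpers the function calls (SwitchableDropoutWrapper, get_last_state, add_reg_without_bias) must be importable.

# Hypothetical usage; shapes, sizes and scope names are illustrative only.
import tensorflow as tf

emb = tf.placeholder(tf.float32, [None, 40, 300])  # [bs, sl, vec]
mask = tf.placeholder(tf.bool, [None, 40])         # [bs, sl]
is_train = tf.placeholder(tf.bool, [])

# All hidden states, [bs, sl, 2*100]
states = contextual_bi_rnn(emb, mask, 100, 'lstm',
                           keep_prob=0.8, is_train=is_train,
                           scope='ctx_enc')

# Only the last valid state per sequence, [bs, 2*100]
final = contextual_bi_rnn(emb, mask, 100, 'lstm', only_final=True,
                          keep_prob=0.8, is_train=is_train,
                          scope='ctx_enc_final')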
Example #2
def do_reduce(self, data_for_reduce, mask_for_reduce):
    with tf.variable_scope('sr_%s' % self.method_type):
        seq_len = tf.reduce_sum(tf.cast(mask_for_reduce, tf.int32), -1)
        (fw, bw), _ = tf.nn.bidirectional_dynamic_rnn(
            self.cell_dp_fw,
            self.cell_dp_bw,
            data_for_reduce,
            seq_len,
            dtype=tf.float32,
            scope='shift_reduce_bilstm_loop')
        value = tf.concat([fw, bw], -1)  # [bs, sl, 2*hn]
        processed_reduced_data = get_last_state(value, mask_for_reduce)
        return processed_reduced_data  # [bs, 2*hn]
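
get_last_state is a project helper that is not shown in this listing. A minimal sketch of a common implementation, assuming a flat batch dimension and a mask that marks valid steps with True (an assumption, not the project's actual code):

import tensorflow as tf

# Hypothetical sketch; the project's real get_last_state may differ.
def get_last_state(rnn_outputs, mask):
    # rnn_outputs: [bs, sl, d]; mask: [bs, sl], True on valid steps
    lengths = tf.reduce_sum(tf.cast(mask, tf.int32), -1)  # [bs]
    last_idx = tf.maximum(lengths - 1, 0)                 # clamp all-padding rows
    batch_idx = tf.range(tf.shape(rnn_outputs)[0])        # [bs]
    gather_idx = tf.stack([batch_idx, last_idx], axis=1)  # [bs, 2]
    return tf.gather_nd(rnn_outputs, gather_idx)          # [bs, d]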
Example #3
def one_direction_rnn(tensor_rep,
                      mask_rep,
                      hn,
                      cell_type,
                      only_final=False,
                      wd=0.,
                      keep_prob=1.,
                      is_train=None,
                      is_forward=True,
                      scope=None):
    assert is_forward  # TODO: backward direction waiting to be implemented
    with tf.variable_scope(scope or '%s_rnn' %
                           ('forward' if is_forward else 'backward')):
        reuse = tf.get_variable_scope().reuse or None
        if cell_type == 'gru':
            cell = tf.contrib.rnn.GRUCell(hn, reuse=reuse)
        elif cell_type == 'lstm':
            cell = tf.contrib.rnn.LSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_lstm':
            cell = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse)
        elif cell_type == 'basic_rnn':
            cell = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse)
        else:
            raise ValueError("unknown cell type '%s'" % cell_type)
        cell_dp = SwitchableDropoutWrapper(cell, is_train, keep_prob)

        tensor_len = tf.reduce_sum(tf.cast(mask_rep, tf.int32), -1)  # [bs]

        rnn_outputs, _ = tf.nn.dynamic_rnn(cell_dp,
                                           tensor_rep,
                                           tensor_len,
                                           dtype=tf.float32)

        if wd > 0:
            add_reg_without_bias()
        if not only_final:
            return rnn_outputs  # [..., sl, hn]
        else:
            return get_last_state(rnn_outputs, mask_rep)  # [..., hn]
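
SwitchableDropoutWrapper is likewise project-specific: it applies dropout only while is_train is true. A minimal stand-in with the same effect, sketched with stock TF 1.x (an assumption, not the project's actual wrapper):

import tensorflow as tf

# Hypothetical stand-in; the real wrapper may also drop recurrent
# connections or share dropout masks across timesteps.
def make_switchable_dropout_cell(cell, is_train, keep_prob):
    # keep_prob during training, 1.0 (no dropout) at inference time
    kp = tf.cond(is_train,
                 lambda: tf.constant(keep_prob, tf.float32),
                 lambda: tf.constant(1.0, tf.float32))
    return tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=kp)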