Example #1
    def __call__(self,
                 inputs,
                 seq_len,
                 return_last_state=False,
                 keep_prob=None,
                 is_train=None):
        with tf.variable_scope(self.scope):
            if return_last_state:
                # Take the final hidden state (h) of each direction's cell.
                _, ((_, output_fw), (_, output_bw)) = _bidirectional_dynamic_rnn(
                    self.cell_fw,
                    self.cell_bw,
                    inputs,
                    sequence_length=seq_len,
                    dtype=tf.float32)
            else:
                # Take the per-timestep outputs of both directions.
                (output_fw, output_bw), _ = _bidirectional_dynamic_rnn(
                    self.cell_fw,
                    self.cell_bw,
                    inputs,
                    sequence_length=seq_len,
                    dtype=tf.float32)
            output = tf.concat([output_fw, output_bw], axis=-1)
            output = dropout(output, keep_prob, is_train)
            return output
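Example #1 is a method of a small wrapper class whose constructor is not shown. A minimal sketch of what that class presumably looks like; the class name, constructor signature, and LSTM cell choice are assumptions inferred from the call sites (the method unpacks (c, h) state tuples, so LSTM-style cells are implied):

import tensorflow as tf

class BiRNN:
    """Hypothetical owner of the __call__ method above."""

    def __init__(self, num_units, scope="birnn"):
        # One cell per direction, reused under a shared variable scope.
        self.cell_fw = tf.nn.rnn_cell.BasicLSTMCell(num_units)
        self.cell_bw = tf.nn.rnn_cell.BasicLSTMCell(num_units)
        self.scope = scope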
Example #2
def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
    assert not time_major

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
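Most of these examples lean on two helpers, flatten and reconstruct, whose definitions are not included. A minimal sketch of what they plausibly do, inferred from the call sites and shape comments (an assumption, not the original implementation): flatten collapses all but the last `keep` dimensions into a single batch dimension, and reconstruct restores the leading dimensions of a reference tensor.

from functools import reduce
from operator import mul

import tensorflow as tf


def flatten(tensor, keep):
    # e.g. [N, M, JX, d] with keep=2 -> [N*M, JX, d]; keep=0 flattens to a vector.
    fixed_shape = tensor.get_shape().as_list()
    start = len(fixed_shape) - keep
    left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)])
    out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i]
                          for i in range(start, len(fixed_shape))]
    return tf.reshape(tensor, out_shape)


def reconstruct(tensor, ref, keep):
    # Inverse of flatten: put the leading dims of `ref` back in front of the
    # last `keep` dims of `tensor`, e.g. [N*M, JX, 2d] -> [N, M, JX, 2d].
    ref_shape = ref.get_shape().as_list()
    tensor_shape = tensor.get_shape().as_list()
    ref_stop = len(ref_shape) - keep
    tensor_start = len(tensor_shape) - keep
    pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)]
    keep_shape = [tensor_shape[i] or tf.shape(tensor)[i]
                  for i in range(tensor_start, len(tensor_shape))]
    return tf.reshape(tensor, pre_shape + keep_shape)

With helpers of this shape, _bidirectional_dynamic_rnn only ever sees a rank-3 [batch, time, features] tensor, which is why the wrappers can accept higher-rank inputs.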
Example #3
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    assert not time_major

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    # flat_inputs = flatten2(inputs)  # [-1, J, d]
    # tmpshape = tf.shape(inputs)
    # flat_inputs = tf.reshape(inputs, (tmpshape[0], -1, tmpshape[-1]))
    # flat_inputs = tf.reshape(inputs, (tmpshape[0], -1, tmpshape[len(tmpshape)-1]))
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
    print("inputs==={}   |||   flat_inputs==={}  |||  flat_len:{}".format(inputs, flat_inputs, flat_len))
    # (flat_fw_outputs, flat_bw_outputs), final_state = \
    #     tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
    #                                initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
    #                                dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
    #                                time_major=time_major, scope=scope)
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
Example #4
def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
    assert not time_major

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    # flat_inputs = inputs
    print("flat inputs shape", flat_inputs.shape)
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state


# def bidirectional_rnn(cell_fw, cell_bw, inputs,
#                       initial_state_fw=None, initial_state_bw=None,
#                       dtype=None, sequence_length=None, scope=None):

#     flat_inputs = flatten(inputs, 1)  # [-1, J, d]
#     flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')

#     (flat_fw_outputs, flat_bw_outputs), final_state = \
#         _bidirectional_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
#                            initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
#                            dtype=dtype, scope=scope)

#     fw_outputs = reconstruct(flat_fw_outputs, inputs, 1)
#     bw_outputs = reconstruct(flat_bw_outputs, inputs, 1)
#     # FIXME : final state is not reshaped!
#     return (fw_outputs, bw_outputs), final_state

# def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size):
#     """
#     Create the LSTM layers
#     """
#     lstms = [tf.contrib.rnn.BasicLSTMCell(size) for size in lstm_sizes]
#     # Add dropout to the cell
#     drops = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_) for lstm in lstms]
#     # Stack up multiple LSTM layers, for deep learning
#     cell = tf.contrib.rnn.MultiRNNCell(drops)
#     # Getting an initial state of all zeros
#     initial_state = cell.zero_state(batch_size, tf.float32)
#     lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, embed, initial_state=initial_state)
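The commented-out build_lstm_layers helper above stops before returning anything and passes no sequence_length to tf.nn.dynamic_rnn. A minimal completed sketch under those assumptions (the return values and the seq_len parameter are additions, not part of the original comment):

import tensorflow as tf


def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size, seq_len=None):
    """Create stacked LSTM layers with dropout and run them over `embed`."""
    lstms = [tf.contrib.rnn.BasicLSTMCell(size) for size in lstm_sizes]
    # Add dropout to each cell
    drops = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_)
             for lstm in lstms]
    # Stack up multiple LSTM layers, for deep learning
    cell = tf.contrib.rnn.MultiRNNCell(drops)
    # Getting an initial state of all zeros
    initial_state = cell.zero_state(batch_size, tf.float32)
    lstm_outputs, final_state = tf.nn.dynamic_rnn(
        cell, embed, sequence_length=seq_len, initial_state=initial_state)
    return lstm_outputs, final_state, cell, initial_state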
Example #5
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    assert not time_major
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # Unlike the other examples above, the final hidden states (h) are
    # reshaped back to the input's leading dims below; the cell states (c) are dropped.
    (_, flat_final_hidden_fw), (_, flat_final_hidden_bw) = final_state

    def my_reconstruct(tensor, ref):
        tensor_shape = tensor.get_shape().as_list()
        ref_shape = ref.get_shape().as_list()
        pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(len(ref_shape) - 2)]
        keep_shape = tensor_shape[-1:]
        target_shape = pre_shape + keep_shape
        out = tf.reshape(tensor, target_shape)
        return out

    final_hidden_fw = my_reconstruct(flat_final_hidden_fw, inputs)
    final_hidden_bw = my_reconstruct(flat_final_hidden_bw, inputs)
    return (fw_outputs, bw_outputs), (final_hidden_fw, final_hidden_bw)
Example #6
def my_bidirectional_dynamic_rnn(cell_fw,
                                 cell_bw,
                                 inputs,
                                 sequence_length=None,
                                 initial_state_fw=None,
                                 initial_state_bw=None,
                                 dtype=None,
                                 parallel_iterations=None,
                                 swap_memory=False,
                                 time_major=False,
                                 scope=None):
    assert not time_major

    flat_inputs = my_flatten(inputs)
    flat_inputs = tf.Print(flat_inputs, [tf.shape(flat_inputs)],
                           message="flat inputs g1 shape:",
                           first_n=5)
    flat_len = None if sequence_length is None else tf.cast(
        my_flatten_2(sequence_length), 'int64')
    flat_len = tf.Print(flat_len, [flat_len],
                        message="flat lengths of g1:",
                        first_n=5)

    final_output, final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    return final_output, final_state, flat_len
Example #7
def my_new_bidirectional_dynamic_rnn(cell_fw,
                                     cell_bw,
                                     inputs,
                                     sequence_length=None,
                                     initial_state_fw=None,
                                     initial_state_bw=None,
                                     dtype=None,
                                     parallel_iterations=None,
                                     swap_memory=False,
                                     time_major=False,
                                     scope=None):
    assert not time_major

    #flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_inputs = my_flatten(inputs)
    #flat_inputs = tf.reshape(inputs,[tf.shape(inputs)[0],-1,tf.shape(inputs)[3]])
    flat_len = sequence_length
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 1)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 1)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
Example #8
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, scope=None):
    flat_inputs = flatten(inputs, 2)  # [-1, seq_len, dim]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), dtype=tf.int64)
    (flat_fw_outputs, flat_bw_outputs), final_state = _bidirectional_dynamic_rnn(
        cell_fw, cell_bw, flat_inputs, sequence_length=flat_len, dtype=tf.float32, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    return (fw_outputs, bw_outputs), final_state
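For reference, a sketch of how a wrapper like Example #8 would typically be invoked on a rank-4 input. The shapes, cell sizes, and GRU choice are illustrative assumptions; it presumes the flatten/reconstruct helpers sketched after Example #2 are in scope, and that _bidirectional_dynamic_rnn stands for tf.nn.bidirectional_dynamic_rnn, as the commented-out call in Example #3 suggests.

import tensorflow as tf

# Hypothetical shapes: N examples, M sentences, JX tokens, d features.
N, M, JX, d, hidden = 32, 4, 50, 100, 75
x = tf.placeholder(tf.float32, [N, M, JX, d])
x_len = tf.placeholder(tf.int32, [N, M])   # token count per sentence

cell_fw = tf.nn.rnn_cell.GRUCell(hidden)
cell_bw = tf.nn.rnn_cell.GRUCell(hidden)

# The wrapper flattens [N, M, JX, d] to [N*M, JX, d], runs the
# bidirectional RNN, and restores the [N, M, ...] leading dims.
(fw_out, bw_out), _ = bidirectional_dynamic_rnn(
    cell_fw, cell_bw, x, sequence_length=x_len, scope="encode")
h = tf.concat([fw_out, bw_out], axis=-1)   # [N, M, JX, 2*hidden]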
Example #9
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    assert not time_major

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
Example #10
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    """

    :param cell_fw:
    :param cell_bw:
    :param inputs: [60, ? , 200]
    :param sequence_length: shape = [60]
    :param initial_state_fw:
    :param initial_state_bw:
    :param dtype: float
    :param parallel_iterations:
    :param swap_memory: false
    :param time_major: false
    :param scope: u1
    :return: the hidden for each state, both direction, and final state for both direction
    """
    assert not time_major

    flat_inputs = flatten(inputs, 2)  # [-1, J, d] [60, ?, 200]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs,  # [60, ?, 200]
                                   sequence_length=flat_len,  # [60]
                                   initial_state_fw=initial_state_fw,
                                   initial_state_bw=initial_state_bw,
                                   dtype=dtype,
                                   parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory,
                                   time_major=time_major,
                                   scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
Example #11
    def decode(self, knowledge_rep, masks, is_train, hparams):
        """
        takes in a knowledge representation
        and output a probability estimation over
        all paragraph tokens on which token should be
        the start of the answer span, and which should be
        the end of the answer span.

        :param knowledge_rep: it is a representation of the paragraph and question,
                              decided by how you choose to implement the encoder
        :return:
        """
        p0 = knowledge_rep
        p_mask, q_mask = masks
        batch_size = hparams.batch_size
        input_keep_prob = hparams.input_keep_prob

        p_len = tf.reduce_sum(tf.cast(p_mask, 'int32'), 1)  # [N]
        q_len = tf.reduce_sum(tf.cast(q_mask, 'int32'), 1)  # [N]

        JX = tf.shape(p_mask)[1]

        with tf.variable_scope("main"):
            cell = BasicLSTMCell(self.state_size, state_is_tuple=True)
            first_cell = SwitchableDropoutWrapper(cell, is_train, input_keep_prob=input_keep_prob)

            # [N, JX, 2d]
            (fw_g0, bw_g0), _ = _bidirectional_dynamic_rnn(first_cell, first_cell, p0, 
                                                           p_len, dtype='float', scope='g0') 
            g0 = tf.concat([fw_g0, bw_g0], 2)
            
            cell = BasicLSTMCell(self.state_size, state_is_tuple=True)
            first_cell = SwitchableDropoutWrapper(cell, is_train, input_keep_prob=input_keep_prob)

            # [N, JX, 2d]
            (fw_g1, bw_g1), _ = _bidirectional_dynamic_rnn(first_cell, first_cell, g0, 
                                                           p_len, dtype='float', scope='g1')  
            g1 = tf.concat([fw_g1, bw_g1], 2)
            logits = linear_logits([g1, p0], self.state_size, 0.0, scope='logits1', 
                                   mask=p_mask, is_train=is_train)

            # TODO: use batch_size
            a1i = softsel(tf.reshape(g1, [batch_size, JX, 2 * self.state_size]), 
                          tf.reshape(logits, [batch_size, JX]))

            a1i = tf.tile(tf.expand_dims(a1i, 1), [1, JX, 1])
            
            flat_logits1 = tf.reshape(logits, [-1, JX])
            flat_yp = tf.nn.softmax(flat_logits1)  # [-1, JX]
            yp1 = tf.reshape(flat_yp, [-1, JX])

            cell = BasicLSTMCell(self.state_size, state_is_tuple=True)
            d_cell = SwitchableDropoutWrapper(cell, is_train, input_keep_prob=input_keep_prob)

            # [N, JX, 2d]
            (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(d_cell, d_cell, 
                                                          tf.concat([p0, g1, a1i, g1 * a1i], 2),
                                                          p_len, dtype='float', scope='g2') 
            g2 = tf.concat([fw_g2, bw_g2], 2)
            logits2 = linear_logits([g2, p0], self.state_size, 0.0, scope='logits2', 
                                    mask=p_mask, is_train=is_train)
            flat_logits2 = tf.reshape(logits2, [-1, JX])
            flat_yp = tf.nn.softmax(flat_logits2)  # [-1, JX]
            yp2 = tf.reshape(flat_yp, [-1, JX])
        return (yp1, flat_logits1), (yp2, flat_logits2)
                              name='enc_embedding')
output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size),
                                                 -1.0, 1.0),
                               name='dec_embedding')
date_input_embed = tf.nn.embedding_lookup(input_embedding, inputs)
date_output_embed = tf.nn.embedding_lookup(output_embedding, outputs)

with tf.variable_scope("encoding") as encoding_scope:
    lstm_enc = tf.contrib.rnn.BasicLSTMCell(nodes)
    #encoder_output, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=date_input_embed, dtype=tf.float32)
    ((encoder_fw_outputs, encoder_bw_outputs),
     (encoder_fw_final_state,
      encoder_bw_final_state)) = (_bidirectional_dynamic_rnn(
          cell_fw=lstm_enc,
          cell_bw=lstm_enc,
          inputs=date_input_embed,
          sequence_length=tf.fill([batch_size], x_seq_length),
          dtype=tf.float32,
          time_major=False))

encoder_output = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
encoder_final_state_c = tf.concat(
    (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
encoder_final_state_h = tf.concat(
    (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
encoder_final_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_final_state_c,
                                                    h=encoder_final_state_h)
"""
with tf.variable_scope("decoding") as decoding_scope:
    lstm_dec = tf.contrib.rnn.BasicLSTMCell(nodes)
    dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed, dtype=tf.float32, initial_state=last_state)