import tensorflow as tf
from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as _bidirectional_dynamic_rnn


def __call__(self, inputs, seq_len, return_last_state=False, keep_prob=None, is_train=None):
    with tf.variable_scope(self.scope):
        if return_last_state:
            # Unpack the final LSTMStateTuples; keep only the h states, discard c.
            _, ((_, output_fw), (_, output_bw)) = _bidirectional_dynamic_rnn(
                self.cell_fw, self.cell_bw, inputs,
                sequence_length=seq_len, dtype=tf.float32)
        else:
            # Keep the per-timestep outputs of both directions.
            (output_fw, output_bw), _ = _bidirectional_dynamic_rnn(
                self.cell_fw, self.cell_bw, inputs,
                sequence_length=seq_len, dtype=tf.float32)
        output = tf.concat([output_fw, output_bw], axis=-1)
        output = dropout(output, keep_prob, is_train)
        return output
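# The `dropout` helper called above is not defined in this snippet. A minimal
# BiDAF-style sketch, assuming `is_train` is a boolean tensor and `keep_prob`
# a Python float (both names are taken from the call above):
def dropout(x, keep_prob, is_train, noise_shape=None, seed=None, name=None):
    with tf.name_scope(name or "dropout"):
        if keep_prob is not None and is_train is not None:
            # Apply dropout only when the `is_train` flag is true at run time.
            return tf.cond(is_train,
                           lambda: tf.nn.dropout(x, keep_prob, noise_shape=noise_shape, seed=seed),
                           lambda: x)
        return x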
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None, swap_memory=False,
                              time_major=False, scope=None):
    assert not time_major  # the flatten/reconstruct logic assumes batch-major inputs
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME: the final state is not reshaped back to the original leading dims!
    return (fw_outputs, bw_outputs), final_state
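# `flatten` and `reconstruct` are assumed helpers (in BiDAF-style code they live
# in a general-utilities module). A sketch of plausible implementations:
# `flatten(tensor, keep)` collapses all but the last `keep` dims into one, and
# `reconstruct(tensor, ref, keep)` undoes that using the leading dims of `ref`.
from functools import reduce
from operator import mul


def flatten(tensor, keep):
    fixed_shape = tensor.get_shape().as_list()
    start = len(fixed_shape) - keep
    # Multiply all leading (possibly dynamic) dims into a single batch dim.
    left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)], 1)
    out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i]
                          for i in range(start, len(fixed_shape))]
    return tf.reshape(tensor, out_shape)


def reconstruct(tensor, ref, keep):
    ref_shape = ref.get_shape().as_list()
    tensor_shape = tensor.get_shape().as_list()
    ref_stop = len(ref_shape) - keep
    tensor_start = len(tensor_shape) - keep
    # Leading dims come from `ref`; trailing dims come from `tensor` itself.
    pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)]
    keep_shape = [tensor_shape[i] or tf.shape(tensor)[i]
                  for i in range(tensor_start, len(tensor_shape))]
    return tf.reshape(tensor, pre_shape + keep_shape)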
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None, swap_memory=False,
                              time_major=False, scope=None):
    assert not time_major
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
    print("inputs==={} ||| flat_inputs==={} ||| flat_len:{}".format(inputs, flat_inputs, flat_len))
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME: the final state is not reshaped back to the original leading dims!
    return (fw_outputs, bw_outputs), final_state
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None, swap_memory=False,
                              time_major=False, scope=None):
    assert not time_major
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    print("flat inputs shape", flat_inputs.shape)
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME: the final state is not reshaped back to the original leading dims!
    return (fw_outputs, bw_outputs), final_state

# def bidirectional_rnn(cell_fw, cell_bw, inputs,
#                       initial_state_fw=None, initial_state_bw=None,
#                       dtype=None, sequence_length=None, scope=None):
#     flat_inputs = flatten(inputs, 1)  # [-1, J, d]
#     flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
#     (flat_fw_outputs, flat_bw_outputs), final_state = \
#         _bidirectional_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
#                            initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
#                            dtype=dtype, scope=scope)
#     fw_outputs = reconstruct(flat_fw_outputs, inputs, 1)
#     bw_outputs = reconstruct(flat_bw_outputs, inputs, 1)
#     # FIXME: final state is not reshaped!
#     return (fw_outputs, bw_outputs), final_state

# def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size):
#     """Create the LSTM layers."""
#     lstms = [tf.contrib.rnn.BasicLSTMCell(size) for size in lstm_sizes]
#     # Add dropout to each cell
#     drops = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_) for lstm in lstms]
#     # Stack up multiple LSTM layers, for deep learning
#     cell = tf.contrib.rnn.MultiRNNCell(drops)
#     # Get an initial state of all zeros
#     initial_state = cell.zero_state(batch_size, tf.float32)
#     lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, embed, initial_state=initial_state)
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None, swap_memory=False,
                              time_major=False, scope=None):
    assert not time_major
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # Unlike the variants above, this one also reshapes the final hidden states
    # back to the leading dims of `inputs` (the c states are discarded).
    (_, flat_final_hidden_fw), (_, flat_final_hidden_bw) = final_state

    def my_reconstruct(tensor, ref):
        tensor_shape = tensor.get_shape().as_list()
        ref_shape = ref.get_shape().as_list()
        pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(len(ref_shape) - 2)]
        keep_shape = tensor_shape[-1:]
        target_shape = pre_shape + keep_shape
        return tf.reshape(tensor, target_shape)

    final_hidden_fw = my_reconstruct(flat_final_hidden_fw, inputs)
    final_hidden_bw = my_reconstruct(flat_final_hidden_bw, inputs)
    return (fw_outputs, bw_outputs), (final_hidden_fw, final_hidden_bw)
def my_bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                                 initial_state_fw=None, initial_state_bw=None,
                                 dtype=None, parallel_iterations=None, swap_memory=False,
                                 time_major=False, scope=None):
    assert not time_major
    flat_inputs = my_flatten(inputs)
    flat_inputs = tf.Print(flat_inputs, [tf.shape(flat_inputs)],
                           message="flat inputs g1 shape:", first_n=5)
    flat_len = None if sequence_length is None else tf.cast(my_flatten_2(sequence_length), 'int64')
    if flat_len is not None:  # tf.Print on None would fail
        flat_len = tf.Print(flat_len, [flat_len], message="flat lengths of g1:", first_n=5)
    final_output, final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    return final_output, final_state, flat_len
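# `my_flatten` and `my_flatten_2` are not shown in this snippet. Judging from how
# they are used above and in the next variant, one plausible sketch (assumed, not
# the original code) collapses all leading dims into the batch dim:
def my_flatten(tensor):
    # [..., J, d] -> [-1, J, d]: merge every leading dim into the batch dim.
    d = tensor.get_shape().as_list()[-1]
    j = tf.shape(tensor)[-2]
    return tf.reshape(tensor, [-1, j, d])


def my_flatten_2(tensor):
    # e.g. a per-sentence length matrix [N, M] -> [N*M].
    return tf.reshape(tensor, [-1])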
def my_new_bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                                     initial_state_fw=None, initial_state_bw=None,
                                     dtype=None, parallel_iterations=None, swap_memory=False,
                                     time_major=False, scope=None):
    assert not time_major
    flat_inputs = my_flatten(inputs)  # collapse all leading dims into the batch dim
    flat_len = sequence_length  # already flat; no extra flattening needed here
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 1)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 1)
    # FIXME: the final state is not reshaped back to the original leading dims!
    return (fw_outputs, bw_outputs), final_state
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, scope=None):
    flat_inputs = flatten(inputs, 2)  # [-1, seq_len, dim]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0),
                                                            dtype=tf.int64)
    (flat_fw_outputs, flat_bw_outputs), final_state = _bidirectional_dynamic_rnn(
        cell_fw, cell_bw, flat_inputs, sequence_length=flat_len, dtype=tf.float32, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    return (fw_outputs, bw_outputs), final_state
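# A minimal usage sketch for the wrapper above. The names `N`, `M`, `J`, `d` and
# the placeholder setup are illustrative, not from the original code:
N, M, J, d = 32, 4, 50, 100          # batch, sentences per example, words per sentence, embed dim
x = tf.placeholder(tf.float32, [N, M, J, d])
x_len = tf.placeholder(tf.int32, [N, M])  # one length per sentence
cell_fw = tf.contrib.rnn.BasicLSTMCell(d)
cell_bw = tf.contrib.rnn.BasicLSTMCell(d)
# The wrapper flattens [N, M, J, d] to [N*M, J, d], runs the BiRNN over the
# word axis, then restores the outputs to [N, M, J, d] in each direction.
(fw_out, bw_out), _ = bidirectional_dynamic_rnn(cell_fw, cell_bw, x,
                                                sequence_length=x_len, scope='example')
h = tf.concat([fw_out, bw_out], axis=-1)  # [N, M, J, 2d]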
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None, swap_memory=False,
                              time_major=False, scope=None):
    """
    :param cell_fw: forward RNN cell
    :param cell_bw: backward RNN cell
    :param inputs: [60, ?, 200]
    :param sequence_length: shape = [60]
    :param initial_state_fw:
    :param initial_state_bw:
    :param dtype: float
    :param parallel_iterations:
    :param swap_memory: False
    :param time_major: False
    :param scope: e.g. 'u1'
    :return: the per-step hidden states for both directions, and the final state
        for both directions
    """
    assert not time_major
    flat_inputs = flatten(inputs, 2)  # [-1, J, d], here [60, ?, 200]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs,  # last dim 200
                                   sequence_length=flat_len,  # dynamic length
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations,
                                   swap_memory=swap_memory, time_major=time_major, scope=scope)
    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME: the final state is not reshaped back to the original leading dims!
    return (fw_outputs, bw_outputs), final_state
def decode(self, knowledge_rep, masks, is_train, hparams):
    """
    Takes in a knowledge representation and outputs a probability estimation over
    all paragraph tokens for which token should be the start of the answer span,
    and which should be the end of the answer span.

    :param knowledge_rep: a representation of the paragraph and question,
        determined by how you choose to implement the encoder
    :return: ((yp1, flat_logits1), (yp2, flat_logits2)) for the start and end positions
    """
    p0 = knowledge_rep
    p_mask, q_mask = masks
    batch_size = hparams.batch_size
    input_keep_prob = hparams.input_keep_prob
    p_len = tf.reduce_sum(tf.cast(p_mask, 'int32'), 1)  # [N]
    q_len = tf.reduce_sum(tf.cast(q_mask, 'int32'), 1)  # [N]
    JX = tf.shape(p_mask)[1]

    with tf.variable_scope("main"):
        cell = BasicLSTMCell(self.state_size, state_is_tuple=True)
        first_cell = SwitchableDropoutWrapper(cell, is_train, input_keep_prob=input_keep_prob)
        # [N, JX, 2d]
        (fw_g0, bw_g0), _ = _bidirectional_dynamic_rnn(first_cell, first_cell, p0, p_len,
                                                       dtype='float', scope='g0')
        g0 = tf.concat([fw_g0, bw_g0], 2)

        cell = BasicLSTMCell(self.state_size, state_is_tuple=True)
        first_cell = SwitchableDropoutWrapper(cell, is_train, input_keep_prob=input_keep_prob)
        # [N, JX, 2d]
        (fw_g1, bw_g1), _ = _bidirectional_dynamic_rnn(first_cell, first_cell, g0, p_len,
                                                       dtype='float', scope='g1')
        g1 = tf.concat([fw_g1, bw_g1], 2)

        logits = linear_logits([g1, p0], self.state_size, 0.0, scope='logits1',
                               mask=p_mask, is_train=is_train)
        # TODO: use batch_size
        a1i = softsel(tf.reshape(g1, [batch_size, JX, 2 * self.state_size]),
                      tf.reshape(logits, [batch_size, JX]))
        a1i = tf.tile(tf.expand_dims(a1i, 1), [1, JX, 1])

        flat_logits1 = tf.reshape(logits, [-1, JX])
        flat_yp = tf.nn.softmax(flat_logits1)  # [N, JX]
        yp1 = tf.reshape(flat_yp, [-1, JX])

        cell = BasicLSTMCell(self.state_size, state_is_tuple=True)
        d_cell = SwitchableDropoutWrapper(cell, is_train, input_keep_prob=input_keep_prob)
        # [N, JX, 8d] -> [N, JX, 2d]
        (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(d_cell, d_cell,
                                                      tf.concat([p0, g1, a1i, g1 * a1i], 2),
                                                      p_len, dtype='float', scope='g2')
        g2 = tf.concat([fw_g2, bw_g2], 2)
        logits2 = linear_logits([g2, p0], self.state_size, 0.0, scope='logits2',
                                mask=p_mask, is_train=is_train)

        flat_logits2 = tf.reshape(logits2, [-1, JX])
        flat_yp = tf.nn.softmax(flat_logits2)  # [N, JX]
        yp2 = tf.reshape(flat_yp, [-1, JX])

    return (yp1, flat_logits1), (yp2, flat_logits2)
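# `softsel` and `linear_logits` come from BiDAF-style utility code and are not
# defined here. A sketch of `softsel` under that assumption (mask handling
# omitted): it softmaxes the logits and takes a weighted sum of `target`.
def softsel(target, logits, scope=None):
    with tf.name_scope(scope or "softsel"):
        a = tf.nn.softmax(logits)                       # [N, JX]
        target_rank = len(target.get_shape().as_list())
        # Weighted sum over the second-to-last axis of `target`:
        # [N, JX, 1] * [N, JX, 2d] summed over JX -> [N, 2d].
        return tf.reduce_sum(tf.expand_dims(a, -1) * target, target_rank - 2)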
                            name='enc_embedding')
output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0),
                               name='dec_embedding')
date_input_embed = tf.nn.embedding_lookup(input_embedding, inputs)
date_output_embed = tf.nn.embedding_lookup(output_embedding, outputs)

with tf.variable_scope("encoding") as encoding_scope:
    lstm_enc = tf.contrib.rnn.BasicLSTMCell(nodes)
    # Unidirectional encoder, replaced by the bidirectional one below:
    # encoder_output, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=date_input_embed, dtype=tf.float32)
    ((encoder_fw_outputs, encoder_bw_outputs),
     (encoder_fw_final_state, encoder_bw_final_state)) = _bidirectional_dynamic_rnn(
        cell_fw=lstm_enc, cell_bw=lstm_enc, inputs=date_input_embed,
        sequence_length=tf.fill([batch_size], x_seq_length),
        dtype=tf.float32, time_major=False)

encoder_output = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
# Concatenate forward and backward final states into a single LSTM state of size 2*nodes.
encoder_final_state_c = tf.concat((encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
encoder_final_state_h = tf.concat((encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
encoder_final_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_final_state_c,
                                                    h=encoder_final_state_h)

# Disabled decoder (kept for reference); it expects `last_state` from the
# unidirectional encoder above:
# with tf.variable_scope("decoding") as decoding_scope:
#     lstm_dec = tf.contrib.rnn.BasicLSTMCell(nodes)
#     dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed,
#                                        dtype=tf.float32, initial_state=last_state)
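# Note: after concatenation, `encoder_final_state` has size 2*nodes, so a decoder
# cell of size `nodes` (as in the disabled block above) could not consume it
# directly. A sketch of a matching decoder; the doubled cell size is my
# assumption, the other names come from the snippet above:
with tf.variable_scope("decoding") as decoding_scope:
    # Decoder cell twice as wide as the encoder cell, to accept the
    # concatenated bidirectional state.
    lstm_dec = tf.contrib.rnn.BasicLSTMCell(2 * nodes)
    dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed,
                                       dtype=tf.float32, initial_state=encoder_final_state)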