def add_model(self):
    """
    input_tensor  # (batch_size, num_sentence, embed_size)
    input_len     # (batch_size,)
    """
    b_sz = tf.shape(self.encoder_input)[0]
    tstp_enc = tf.shape(self.encoder_input)[1]
    tstp_dec = tf.shape(self.ph_decoder_label)[1]

    enc_in = self.encoder_input          # shape(b_sz, tstp_enc, s_emb_sz)
    enc_len = self.ph_input_encoder_len  # shape(b_sz,)
    dec_len = self.ph_input_encoder_len  # shape(b_sz,), decoder length equals encoder length (no noise)
    order_idx = self.ph_decoder_label    # shape(b_sz, tstp_dec)
    cell_dec = rnn_cell.BasicLSTMCell(self.config.h_dec_sz)

    with tf.variable_scope('add_model') as vscope:
        out_logits = self.train_module(   # shape(b_sz, tstp_dec, tstp_enc)
            cell_dec, enc_in, enc_len, dec_len, order_idx,
            scope='decoder_train')
        vscope.reuse_variables()
        predict_idx = self.decoder_test(  # shape(b_sz, tstp_dec)
            cell_dec, enc_in, enc_len, dec_len,
            scope='decoder_train')

    train_loss, valid_loss = self.add_loss_op(out_logits, order_idx, dec_len)
    return train_loss, valid_loss, predict_idx
def basic_lstm_model(inputs):
    print('Loading basic lstm model..')
    for i in range(self.config.rnn_numLayers):
        with tf.variable_scope('rnnLayer' + str(i)):
            lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
            outputs, _ = tf.nn.dynamic_rnn(          # (b_sz, tstp, h_sz)
                lstm_cell, inputs, self.ph_seqLen,
                dtype=tf.float32, swap_memory=True,
                scope='basic_lstm_model_layer-' + str(i))
            inputs = outputs                         # (b_sz, tstp, h_sz)

    # Average the outputs over the valid time steps to get a fixed-size representation.
    aggregate_state = TfUtils.reduce_avg(outputs, self.ph_seqLen, dim=1)  # (b_sz, h_sz)

    inputs = aggregate_state
    inputs = tf.reshape(inputs, [-1, self.config.hidden_size])
    for i in range(self.config.fnn_numLayers):
        inputs = TfUtils.linear(inputs, self.config.hidden_size, bias=True,
                                scope='fnn_layer-' + str(i))
        inputs = tf.nn.tanh(inputs)
    aggregate_state = inputs

    logits = TfUtils.linear(aggregate_state, self.config.class_num, bias=True,
                            scope='fnn_softmax')
    return logits
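# Illustrative sketch (an assumption, not the project's TfUtils implementation):
# basic_lstm_model above relies on TfUtils.reduce_avg to average RNN outputs over
# the valid time steps only. Assuming that helper behaves as its call sites suggest,
# a minimal length-masked mean pooling could look like this.
import tensorflow as tf

def masked_mean_pool(outputs, seq_len):
    """outputs: (b_sz, tstp, h_sz) float tensor; seq_len: (b_sz,) int tensor."""
    tstp = tf.shape(outputs)[1]
    mask = tf.sequence_mask(seq_len, maxlen=tstp, dtype=outputs.dtype)  # (b_sz, tstp)
    mask = tf.expand_dims(mask, axis=2)                                 # (b_sz, tstp, 1)
    summed = tf.reduce_sum(outputs * mask, axis=1)                      # (b_sz, h_sz)
    denom = tf.maximum(tf.cast(seq_len, outputs.dtype), 1.0)            # avoid divide-by-zero
    return summed / tf.expand_dims(denom, axis=1)                       # (b_sz, h_sz)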
def lstm_sentence_rep(input):
    with tf.variable_scope('lstm_sentence_rep_scope') as scope:
        # (b_sz*tstps_en, len_sen, emb_sz)
        input = tf.reshape(input, shape=[b_sz * tstps_en, -1, emb_sz])
        # (b_sz*tstps_en,)
        length = tf.reshape(self.ph_input_encoder_sentence_len, shape=[-1])

        lstm_cell = rnn_cell.BasicLSTMCell(h_sz)
        # tup(shape(b_sz*tstp_enc, len_sen, h_sz))
        rep_out, _ = tf.nn.bidirectional_dynamic_rnn(
            lstm_cell, lstm_cell, input, length,
            dtype=tf.float32, swap_memory=True, time_major=False,
            scope='sentence_encode')

        rep_out = tf.concat(axis=2, values=rep_out)            # (b_sz*tstps_en, len_sen, h_sz*2)
        rep_out = TfUtils.reduce_avg(rep_out, length, dim=1)   # (b_sz*tstps_en, h_sz*2)

    output = tf.reshape(rep_out, shape=[b_sz, tstps_en, 2 * h_sz])  # (b_sz, tstps_en, h_sz*2)
    return output, None, None
def train_module(self, cell_dec, encoder_inputs, enc_lengths, dec_lengths,
                 order_index, scope=None):
    '''
    Args:
        cell_dec       : lstm cell object, a configuration
        encoder_inputs : shape(b_sz, tstp_enc, s_emb_sz)
        enc_lengths    : shape(b_sz,), encoder input lengths
        dec_lengths    : shape(b_sz,), decoder input lengths
        order_index    : shape(b_sz, tstp_dec), decoder label
    '''
    small_num = -np.inf
    input_shape = tf.shape(encoder_inputs)
    b_sz = input_shape[0]
    tstp_enc = input_shape[1]
    tstp_dec = tstp_enc  # since no noise, time step of decoder should be the same as encoder

    h_enc_sz = self.config.h_enc_sz
    h_dec_sz = self.config.h_dec_sz
    s_emb_sz = int(encoder_inputs.get_shape()[2])  # should be a python-determined number
    cell_enc = rnn_cell.BasicLSTMCell(self.config.h_enc_sz)

    def enc(dec_h, in_x, lengths, fake_call=False):
        '''
        Args:
            dec_h:   shape(b_sz, tstp_dec, h_dec_sz)
            in_x:    shape(b_sz, tstp_enc, s_emb_sz)
            lengths: shape(b_sz,)
        Returns:
            res: shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
        '''
        def func_f(in_x, enc_h, in_h_hat, fake_call=False):
            '''
            Args:
                in_x:     shape(b_sz, tstp_enc, s_emb_sz)
                enc_h:    shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                in_h_hat: shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
            Returns:
                res: shape(b_sz, tstp_dec, tstp_enc, s_emb_sz + h_enc_sz*4)
            '''
            if fake_call:
                return s_emb_sz + h_enc_sz * 4

            in_x_sz = int(in_x.get_shape()[-1])
            in_h_sz = int(enc_h.get_shape()[-1])
            if not in_x_sz:
                raise ValueError('last dimension of the first'
                                 ' arg should be known, while got %s'
                                 % (str(type(in_x_sz))))
            if not in_h_sz:
                raise ValueError('last dimension of the second'
                                 ' arg should be known, while got %s'
                                 % (str(type(in_h_sz))))
            enc_in_ex = tf.expand_dims(in_x, 1)   # shape(b_sz, 1, tstp_enc, s_emb_sz)
            enc_in = tf.tile(enc_in_ex,           # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz)
                             [1, tstp_dec, 1, 1])
            res = tf.concat(axis=3, values=[enc_in, enc_h, in_h_hat])
            return res  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz + h_enc_sz*4)

        def attend(enc_h, enc_len):
            '''
            Args:
                enc_h:   shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                enc_len: shape(b_sz,)
            '''
            enc_len = tf.expand_dims(enc_len, 1)              # shape(b_sz, 1)
            attn_enc_len = tf.tile(enc_len, [1, tstp_dec])
            attn_enc_len = tf.reshape(attn_enc_len, [b_sz * tstp_dec])
            attn_enc_h = tf.reshape(                          # shape(b_sz*tstp_dec, tstp_enc, h_enc_sz*2)
                enc_h,
                [b_sz * tstp_dec, tstp_enc, int(enc_h.get_shape()[-1])])
            attn_out = TfUtils.self_attn(                     # shape(b_sz*tstp_dec, tstp_enc, h_enc_sz*2)
                attn_enc_h, attn_enc_len)
            h_hat = tf.reshape(                               # shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                attn_out,
                [b_sz, tstp_dec, tstp_enc, int(attn_out.get_shape()[-1])])
            return h_hat

        if fake_call:
            return func_f(None, None, None, fake_call=True)

        def get_lstm_in_len():
            inputs = func_enc_input(dec_h, in_x)  # shape(b_sz, tstp_dec, tstp_enc, enc_emb_sz)
            enc_emb_sz = int(inputs.get_shape()[-1])
            enc_in = tf.reshape(inputs, shape=[b_sz * tstp_dec, tstp_enc, enc_emb_sz])
            enc_len = tf.expand_dims(lengths, 1)               # shape(b_sz, 1)
            enc_len = tf.tile(enc_len, [1, tstp_dec])          # shape(b_sz, tstp_dec)
            enc_len = tf.reshape(enc_len, [b_sz * tstp_dec])   # shape(b_sz*tstp_dec,)
            return enc_in, enc_len

        # shape(b_sz*tstp_dec, tstp_enc, enc_emb_sz), shape(b_sz*tstp_dec,)
        enc_in, enc_len = get_lstm_in_len()

        # tup(shape(b_sz*tstp_dec, tstp_enc, h_enc_sz))
        lstm_out, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_enc, cell_enc, enc_in, enc_len,
            swap_memory=True, dtype=tf.float32, scope='sent_encoder')
        enc_out = tf.concat(axis=2, values=lstm_out)   # shape(b_sz*tstp_dec, tstp_enc, h_enc_sz*2)
        enc_out = tf.reshape(                          # shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
            enc_out, shape=[b_sz, tstp_dec, tstp_enc, h_enc_sz * 2])
        enc_out_hat = attend(enc_out, lengths)
        res = func_f(in_x, enc_out, enc_out_hat)
        return res  # shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)

    def func_enc_input(dec_h, enc_input, fake_call=False):
        '''
        Args:
            enc_input: encoder input, shape(b_sz, tstp_enc, s_emb_sz)
            dec_h:     decoder hidden state, shape(b_sz, tstp_dec, h_dec_sz)
        Returns:
            output: shape(b_sz, tstp_dec, tstp_enc, s_emb_sz + h_dec_sz)
        '''
        enc_emb_sz = s_emb_sz + h_dec_sz
        if fake_call:
            return enc_emb_sz

        dec_h_ex = tf.expand_dims(dec_h, 2)       # shape(b_sz, tstp_dec, 1, h_dec_sz)
        dec_h_tile = tf.tile(dec_h_ex,            # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
                             [1, 1, tstp_enc, 1])
        enc_in_ex = tf.expand_dims(enc_input, 1)  # shape(b_sz, 1, tstp_enc, s_emb_sz)
        enc_in_tile = tf.tile(enc_in_ex,          # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz)
                              [1, tstp_dec, 1, 1])
        output = tf.concat(axis=3,                # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz + h_dec_sz)
                           values=[enc_in_tile, dec_h_tile])
        output = tf.reshape(output,
                            shape=[b_sz, tstp_dec, tstp_enc, s_emb_sz + h_dec_sz])
        return output  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz + h_dec_sz)

    def func_point_logits(dec_h, enc_ptr, enc_len):
        '''
        Args:
            dec_h   : shape(b_sz, tstp_dec, h_dec_sz)
            enc_ptr : shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
            enc_len : shape(b_sz,)
        '''
        dec_h_ex = tf.expand_dims(dec_h, axis=2)            # shape(b_sz, tstp_dec, 1, h_dec_sz)
        dec_h_ex = tf.tile(dec_h_ex, [1, 1, tstp_enc, 1])   # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
        linear_concat = tf.concat(axis=3,                   # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz + Ptr_sz)
                                  values=[dec_h_ex, enc_ptr])
        point_linear = TfUtils.last_dim_linear(             # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
            linear_concat, output_size=h_dec_sz, bias=False, scope='Ptr_W')
        point_v = TfUtils.last_dim_linear(                  # shape(b_sz, tstp_dec, tstp_enc, 1)
            tf.tanh(point_linear), output_size=1, bias=False, scope='Ptr_V')
        point_logits = tf.squeeze(point_v, axis=[3])        # shape(b_sz, tstp_dec, tstp_enc)

        enc_len = tf.expand_dims(enc_len, 1)                # shape(b_sz, 1)
        enc_len = tf.tile(enc_len, [1, tstp_dec])           # shape(b_sz, tstp_dec)
        mask = TfUtils.mkMask(enc_len, maxLen=tstp_enc)     # shape(b_sz, tstp_dec, tstp_enc)
        point_logits = tf.where(mask, point_logits,         # shape(b_sz, tstp_dec, tstp_enc)
                                tf.ones_like(point_logits) * small_num)
        return point_logits

    def get_initial_state(hidden_sz):
        '''
        Args:
            hidden_sz: must be a python determined number
        '''
        avg_in_x = TfUtils.reduce_avg(encoder_inputs,   # shape(b_sz, s_emb_sz)
                                      enc_lengths, dim=1)
        state = TfUtils.linear(avg_in_x, hidden_sz,     # shape(b_sz, hidden_sz)
                               bias=False, scope='initial_transformation')
        state = rnn_cell.LSTMStateTuple(state, tf.zeros_like(state))
        return state

    def get_bos(emb_sz):
        with tf.variable_scope('bos_scope') as vscope:
            try:
                ret = tf.get_variable(name='bos', shape=[1, emb_sz], dtype=tf.float32)
            except ValueError:
                vscope.reuse_variables()
                ret = tf.get_variable(name='bos', shape=[1, emb_sz], dtype=tf.float32)
        ret_bos = tf.tile(ret, [b_sz, 1])
        return ret_bos

    def decoder():
        def get_dec_in():
            dec_in = TfUtils.batch_embed_lookup(encoder_inputs, order_index)  # shape(b_sz, tstp_dec, s_emb_sz)
            bos = get_bos(s_emb_sz)       # shape(b_sz, s_emb_sz)
            bos = tf.expand_dims(bos, 1)  # shape(b_sz, 1, s_emb_sz)
            dec_in = tf.concat(axis=1, values=[bos, dec_in])  # shape(b_sz, tstp_dec+1, s_emb_sz)
            dec_in = dec_in[:, :-1, :]    # shape(b_sz, tstp_dec, s_emb_sz)
            return dec_in

        dec_in = get_dec_in()                         # shape(b_sz, tstp_dec, s_emb_sz)
        initial_state = get_initial_state(h_dec_sz)   # shape(b_sz, h_dec_sz)
        dec_out, _ = tf.nn.dynamic_rnn(               # shape(b_sz, tstp_dec, h_dec_sz)
            cell_dec, dec_in, dec_lengths,
            initial_state=initial_state, swap_memory=True,
            dtype=tf.float32, scope=scope)
        with tf.variable_scope(scope):
            enc_out = enc(dec_out,                    # shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
                          encoder_inputs, enc_lengths)
            point_logits = func_point_logits(dec_out, enc_out,
                                             enc_lengths)  # shape(b_sz, tstp_dec, tstp_enc)
        return point_logits

    point_logits = decoder()  # shape(b_sz, tstp_dec, tstp_enc)
    return point_logits
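# Illustrative sketch of the additive pointer scoring that func_point_logits above
# computes -- v^T tanh(W [dec_h ; enc_ptr]) followed by length masking -- written in
# NumPy for a single decoder step. W and v here are hypothetical stand-ins for the
# variables created by TfUtils.last_dim_linear under the 'Ptr_W' / 'Ptr_V' scopes;
# this is an assumption-laden sketch, not the project's implementation.
import numpy as np

def pointer_logits_step(dec_h, enc_ptr, enc_len, W, v, small_num=-np.inf):
    """dec_h: (h_dec,), enc_ptr: (tstp_enc, ptr_sz), W: (h_dec+ptr_sz, h_dec), v: (h_dec,)."""
    tstp_enc = enc_ptr.shape[0]
    dec_tiled = np.tile(dec_h[None, :], (tstp_enc, 1))       # (tstp_enc, h_dec)
    concat = np.concatenate([dec_tiled, enc_ptr], axis=-1)   # (tstp_enc, h_dec + ptr_sz)
    scores = np.tanh(concat @ W) @ v                         # (tstp_enc,)
    scores[enc_len:] = small_num                             # mask padded encoder positions
    return scores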
def decoder_test(self, cell_dec, encoder_inputs, enc_lengths, dec_lengths, scope=None):
    '''
    Args:
        cell_dec       : lstm cell object, a configuration
        encoder_inputs : shape(b_sz, tstp_enc, s_emb_sz)
        enc_lengths    : shape(b_sz,), encoder input lengths
        dec_lengths    : shape(b_sz,), decoder input lengths
    '''
    small_num = -np.inf
    input_shape = tf.shape(encoder_inputs)
    b_sz = input_shape[0]
    tstp_enc = input_shape[1]
    tstp_dec = tstp_enc  # since no noise, time step of decoder should be the same as encoder

    h_enc_sz = self.config.h_enc_sz
    h_dec_sz = self.config.h_dec_sz
    s_emb_sz = int(encoder_inputs.get_shape()[2])  # should be a python-determined number
    # dec_emb_sz is not determined until enc() is traced with fake_call
    cell_enc = rnn_cell.BasicLSTMCell(self.config.h_enc_sz)

    def enc(dec_h, in_x, lengths, fake_call=False):
        '''
        Args:
            dec_h:   shape(b_sz, h_dec_sz)
            in_x:    shape(b_sz, tstp_enc, s_emb_sz)
            lengths: shape(b_sz,)
        '''
        def func_f(in_x, in_h, in_h_hat, fake_call=False):
            if fake_call:
                return s_emb_sz + h_enc_sz * 4
            in_x_sz = int(in_x.get_shape()[-1])
            in_h_sz = int(in_h.get_shape()[-1])
            if not in_x_sz:
                raise ValueError('last dimension of the first'
                                 ' arg should be known, while got %s'
                                 % (str(type(in_x_sz))))
            if not in_h_sz:
                raise ValueError('last dimension of the second'
                                 ' arg should be known, while got %s'
                                 % (str(type(in_h_sz))))
            res = tf.concat(axis=2, values=[in_x, in_h, in_h_hat])
            return res

        if fake_call:
            return func_f(None, None, None, fake_call=True)

        inputs = func_enc_input(dec_h, in_x)
        lstm_out, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_enc, cell_enc, inputs, lengths,
            swap_memory=True, dtype=tf.float32, scope='sent_encoder')
        enc_out = tf.concat(axis=2, values=lstm_out)  # shape(b_sz, tstp_enc, h_enc_sz*2)
        enc_out = tf.reshape(enc_out, [b_sz, tstp_enc, h_enc_sz * 2])
        enc_out_hat = TfUtils.self_attn(enc_out, lengths)
        res = func_f(in_x, enc_out, enc_out_hat)
        return res  # shape(b_sz, tstp_enc, dec_emb_sz)

    def func_enc_input(dec_h, enc_input, fake_call=False):
        '''
        Args:
            enc_input: encoder input, shape(b_sz, tstp_enc, s_emb_sz)
            dec_h:     decoder hidden state, shape(b_sz, h_dec_sz)
        '''
        enc_emb_sz = s_emb_sz + h_dec_sz
        if fake_call:
            return enc_emb_sz

        dec_h_ex = tf.expand_dims(dec_h, 1)               # shape(b_sz, 1, h_dec_sz)
        dec_h_tile = tf.tile(dec_h_ex, [1, tstp_enc, 1])
        output = tf.concat(axis=2,                        # shape(b_sz, tstp_enc, s_emb_sz + h_dec_sz)
                           values=[enc_input, dec_h_tile])
        output = tf.reshape(output, shape=[b_sz, tstp_enc, s_emb_sz + h_dec_sz])
        return output  # shape(b_sz, tstp_enc, s_emb_sz + h_dec_sz)

    enc_emb_sz = func_enc_input(None, None, fake_call=True)
    dec_emb_sz = enc(None, None, None, fake_call=True)

    def func_point_logits(dec_h, enc_e, enc_len):
        '''
        Args:
            dec_h   : shape(b_sz, h_dec_sz)
            enc_e   : shape(b_sz, tstp_enc, dec_emb_sz)
            enc_len : shape(b_sz,)
        '''
        dec_h_ex = tf.expand_dims(dec_h, axis=1)          # shape(b_sz, 1, h_dec_sz)
        dec_h_ex = tf.tile(dec_h_ex, [1, tstp_enc, 1])    # shape(b_sz, tstp_enc, h_dec_sz)
        linear_concat = tf.concat(axis=2,                 # shape(b_sz, tstp_enc, h_dec_sz + dec_emb_sz)
                                  values=[dec_h_ex, enc_e])
        point_linear = TfUtils.last_dim_linear(           # shape(b_sz, tstp_enc, h_dec_sz)
            linear_concat, output_size=h_dec_sz, bias=False, scope='Ptr_W')
        point_v = TfUtils.last_dim_linear(                # shape(b_sz, tstp_enc, 1)
            tf.tanh(point_linear), output_size=1, bias=False, scope='Ptr_V')
        point_logits = tf.squeeze(point_v, axis=[2])      # shape(b_sz, tstp_enc)
        mask = TfUtils.mkMask(enc_len, maxLen=tstp_enc)   # shape(b_sz, tstp_enc)
        point_logits = tf.where(mask, point_logits,       # shape(b_sz, tstp_enc)
                                tf.ones_like(point_logits) * small_num)
        return point_logits

    def func_point_idx(dec_h, enc_e, enc_len, hit_mask):
        '''
        Args:
            hit_mask: shape(b_sz, tstp_enc)
        '''
        logits = func_point_logits(dec_h, enc_e, enc_len)  # shape(b_sz, tstp_enc)
        prob = tf.nn.softmax(logits)
        # Zero out positions that have already been pointed to.
        prob = tf.where(hit_mask, tf.zeros_like(prob), prob, name='mask_hit_pos')
        idx = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)  # shape(b_sz,), int32
        return idx  # shape(b_sz,)

    def get_bos(emb_sz):
        with tf.variable_scope('bos_scope') as vscope:
            try:
                ret = tf.get_variable(name='bos', shape=[1, emb_sz], dtype=tf.float32)
            except ValueError:
                vscope.reuse_variables()
                ret = tf.get_variable(name='bos', shape=[1, emb_sz], dtype=tf.float32)
        ret_bos = tf.tile(ret, [b_sz, 1])
        return ret_bos

    def get_initial_state(hidden_sz):
        '''
        Args:
            hidden_sz: must be a python determined number
        '''
        avg_in_x = TfUtils.reduce_avg(encoder_inputs,   # shape(b_sz, s_emb_sz)
                                      enc_lengths, dim=1)
        state = TfUtils.linear(avg_in_x, hidden_sz,     # shape(b_sz, hidden_sz)
                               bias=False, scope='initial_transformation')
        state = rnn_cell.LSTMStateTuple(state, tf.zeros_like(state))
        return state

    bos = get_bos(s_emb_sz)  # shape(b_sz, s_emb_sz)
    init_state = get_initial_state(h_dec_sz)

    def loop_fn(time, cell_output, cell_state, hit_mask):
        """
        Args:
            cell_output: shape(b_sz, h_dec_sz)  ==> d
            cell_state : tup(shape(b_sz, h_dec_sz))
            hit_mask   : shape(b_sz, tstp_enc)
        """
        if cell_output is None:  # time == 0
            next_cell_state = init_state
            next_input = bos  # shape(b_sz, s_emb_sz)
            next_idx = tf.zeros(shape=[b_sz], dtype=tf.int32)  # shape(b_sz,)
            elements_finished = tf.zeros(shape=[b_sz], dtype=tf.bool,
                                         name='elem_finished')
            next_hit_mask = tf.zeros(shape=[b_sz, tstp_enc], dtype=tf.bool,
                                     name='hit_mask')
        else:
            next_cell_state = cell_state
            encoder_e = enc(cell_output, encoder_inputs,
                            enc_lengths)  # shape(b_sz, tstp_enc, dec_emb_sz)
            next_idx = func_point_idx(cell_output, encoder_e,
                                      enc_lengths, hit_mask)  # shape(b_sz,)
            cur_hit_mask = tf.one_hot(next_idx, on_value=True,   # shape(b_sz, tstp_enc)
                                      off_value=False, depth=tstp_enc, dtype=tf.bool)
            next_hit_mask = tf.logical_or(hit_mask, cur_hit_mask,  # shape(b_sz, tstp_enc)
                                          name='next_hit_mask')
            next_input = TfUtils.batch_embed_lookup(encoder_inputs,
                                                    next_idx)  # shape(b_sz, s_emb_sz)
            elements_finished = (time >= dec_lengths)  # shape(b_sz,)
        return (elements_finished, next_input, next_cell_state, next_hit_mask, next_idx)

    emit_idx_ta, _ = myRNN.train_rnn(cell_dec, loop_fn, scope=scope)
    output_idx = emit_idx_ta.stack()                     # shape(tstp_dec, b_sz)
    output_idx = tf.transpose(output_idx, perm=[1, 0])   # shape(b_sz, tstp_dec)
    return output_idx  # shape(b_sz, tstp_dec)
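# Illustrative sketch of the greedy decoding loop that decoder_test implements: at
# each step the already selected positions are masked out (the "hit_mask") so every
# encoder position is pointed to at most once. A pure-NumPy toy, assuming a
# precomputed per-step score matrix rather than the live pointer network above.
import numpy as np

def greedy_pointer_decode(score_matrix):
    """score_matrix: (tstp_dec, tstp_enc) logits; returns the predicted ordering."""
    tstp_dec, tstp_enc = score_matrix.shape
    hit_mask = np.zeros(tstp_enc, dtype=bool)
    order = []
    for t in range(tstp_dec):
        scores = np.where(hit_mask, -np.inf, score_matrix[t])  # suppress used positions
        idx = int(np.argmax(scores))
        hit_mask[idx] = True
        order.append(idx)
    return order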
def add_logits_op2(self):
    """Use a BiLSTM to produce the logits: every word of every sentence in the batch
    gets a result, an n-dimensional vector where n is the number of classes."""
    with tf.variable_scope('Premise_encoder'):
        lstm_cell = rnn_cell.BasicLSTMCell(hidden_size_lstm)
        lstm_cell = rnn_cell.DropoutWrapper(lstm_cell,
                                            input_keep_prob=self.dropout,
                                            output_keep_prob=self.dropout)
        Premise_out, Premise_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell, cell_bw=lstm_cell,
            inputs=self.seq1_word_embeddings,
            sequence_length=self.sequence1_lengths,
            dtype=tf.float32, swap_memory=True)
        Premise_output_fw, Premise_output_bw = Premise_out
        Premise_states_fw, Premise_states_bw = Premise_state
        Premise_out = tf.concat(Premise_out, 2)
        Premise_state = tf.concat(Premise_state, 2)

    with tf.variable_scope('Hypothesis_encoder'):
        lstm_cell = rnn_cell.BasicLSTMCell(hidden_size_lstm)
        lstm_cell = rnn_cell.DropoutWrapper(lstm_cell,
                                            input_keep_prob=self.dropout,
                                            output_keep_prob=self.dropout)
        Hypo_out, Hypo_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell, cell_bw=lstm_cell,
            inputs=self.seq2_word_embeddings,
            sequence_length=self.sequence2_lengths,
            # initial_state_fw=Premise_states_fw,
            # initial_state_bw=Premise_states_bw,
            dtype=tf.float32, swap_memory=True)
        print('before=', np.shape(Hypo_state[1]))
        Hypo_out = tf.concat(Hypo_out, 2)
        Hypo_state = tf.concat(Hypo_state, 2)

    def w2w_attn(Premise_out, Hypo_out, seqLen_Premise, seqLen_Hypo, scope=None):
        with tf.variable_scope(scope or 'Attn_layer'):
            attn_cell = AttnCell(196 * 2, Premise_out, seqLen_Premise)
            attn_cell = rnn_cell.DropoutWrapper(attn_cell,
                                                input_keep_prob=self.dropout,
                                                output_keep_prob=self.dropout)
            _, r_state = tf.nn.dynamic_rnn(attn_cell, Hypo_out, seqLen_Hypo,
                                           dtype=Hypo_out.dtype, swap_memory=True)
        return r_state

    r_L = w2w_attn(Premise_out, Hypo_out,
                   self.sequence1_lengths, self.sequence2_lengths,
                   scope='w2w_attention')

    hypo_state1 = tf.reshape(Hypo_state[1], [-1, 392])
    hypo_state1 = tf.nn.dropout(hypo_state1, 0.5)
    print('***********', np.shape(r_L))
    print('***********', np.shape(hypo_state1))

    h_star = tf.tanh(
        linear([r_L, hypo_state1],  # shape(b_sz, h_sz)
               392, bias=False, scope='linear_trans'))
    input_fully = h_star

    output = tf.nn.dropout(input_fully, self.dropout)
    W = tf.get_variable('W', dtype=tf.float32,
                        shape=[hidden_size_lstm * 2, ntags])
    b = tf.get_variable('b', dtype=tf.float32, shape=[ntags],
                        initializer=tf.zeros_initializer())
    pred = tf.matmul(output, W) + b
    logits = tf.reshape(pred, [-1, ntags])
    logits = tf.nn.softmax(logits)
    self.logits = logits

    '''
    for i in range(2):
        with tf.variable_scope('fully_connect_' + str(i)):
            logits = tf.contrib.layers.fully_connected(
                input_fully, 300 * 2, activation_fn=None)
            input_fully = tf.tanh(logits)

    with tf.name_scope('Softmax'):
        logits = tf.contrib.layers.fully_connected(
            input_fully, self.config.class_num, activation_fn=None)
        self.logits = logits
    '''
    '''
    output1 = attention((output_fw12, output_bw12), attention_size,
                        return_alphas=False)
    # output_fw21 = tf.concat([output_fw2, output_fw1], axis=1)
    # output_bw21 = tf.concat([output_bw2, output_bw1], axis=1)
    output_fw21 = output_fw2
    output_bw21 = output_bw2
    output2 = attention((output_fw21, output_bw21), attention_size,
                        return_alphas=False)
    # output = output1 + output2
    print('output1=', np.shape(output1))
    print('output2=', np.shape(output2))
    output = tf.concat([output1, output2], axis=1)
    output = tf.nn.dropout(output, self.dropout)  # dropout
    print('shape of output=', np.shape(output))
    # Build the projection layer.
    # W = tf.get_variable('W', dtype=tf.float32,
    #                     shape=[hidden_size_lstm*2, ntags])
    W = tf.get_variable('W', dtype=tf.float32,
                        shape=[hidden_size_lstm * 4, ntags])
    b = tf.get_variable('b', dtype=tf.float32, shape=[ntags],
                        initializer=tf.zeros_initializer())
    pred = tf.matmul(output, W) + b
    logits = tf.reshape(pred, [-1, ntags])
    logits = tf.nn.softmax(logits)
    self.logits = logits
    '''
def add_model(self, input_x1, input_x2, seqLen_x1, seqLen_x2):
    '''
    dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None)
    '''
    with tf.variable_scope('Premise_encoder'):
        lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
        lstm_cell = rnn_cell.DropoutWrapper(
            lstm_cell,
            input_keep_prob=self.config.dropout,
            output_keep_prob=self.config.dropout)
        Premise_out, Premise_state = tf.nn.dynamic_rnn(
            cell=lstm_cell, inputs=input_x1, sequence_length=seqLen_x1,
            dtype=tf.float32, swap_memory=True)

    with tf.variable_scope('Hypothesis_encoder'):
        lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
        lstm_cell = rnn_cell.DropoutWrapper(
            lstm_cell,
            input_keep_prob=self.config.dropout,
            output_keep_prob=self.config.dropout)
        Hypo_out, Hypo_state = tf.nn.dynamic_rnn(
            cell=lstm_cell, inputs=input_x2, sequence_length=seqLen_x2,
            initial_state=Premise_state, swap_memory=True)

    def w2w_attn(Premise_out, Hypo_out, seqLen_Premise, seqLen_Hypo, scope=None):
        with tf.variable_scope(scope or 'Attn_layer'):
            attn_cell = AttnCell(self.config.hidden_size, Premise_out, seqLen_Premise)
            attn_cell = rnn_cell.DropoutWrapper(
                attn_cell,
                input_keep_prob=self.config.dropout,
                output_keep_prob=self.config.dropout)
            _, r_state = tf.nn.dynamic_rnn(attn_cell, Hypo_out, seqLen_Hypo,
                                           dtype=Hypo_out.dtype, swap_memory=True)
        return r_state

    r_L = w2w_attn(Premise_out, Hypo_out, seqLen_x1, seqLen_x2,
                   scope='w2w_attention')

    h_star = tf.tanh(
        linear([r_L, Hypo_state[1]],  # shape(b_sz, h_sz)
               self.config.hidden_size, bias=False, scope='linear_trans'))
    input_fully = h_star

    for i in range(self.config.fnn_layers):
        with tf.variable_scope('fully_connect_' + str(i)):
            logits = tf.contrib.layers.fully_connected(
                input_fully, self.config.hidden_size * 2, activation_fn=None)
            input_fully = tf.tanh(logits)

    with tf.name_scope('Softmax'):
        logits = tf.contrib.layers.fully_connected(
            input_fully, self.config.class_num, activation_fn=None)
    return logits
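# Illustrative sketch of one word-by-word attention step that the project-specific
# AttnCell presumably performs inside w2w_attn above: attend over the premise
# outputs conditioned on the current hypothesis output and the previous attention
# state r_prev, producing the final r_L consumed by h_star. The weight matrices
# (W_y, W_h, W_r, W_t, w) are hypothetical stand-ins, not names from the repository.
import numpy as np

def w2w_attention_step(premise_Y, hypo_h, r_prev, W_y, W_h, W_r, W_t, w):
    """premise_Y: (L, k), hypo_h: (k,), r_prev: (k,); returns the new attention state r."""
    M = np.tanh(premise_Y @ W_y + (hypo_h @ W_h + r_prev @ W_r))  # (L, k), broadcast over L
    alpha = np.exp(M @ w)
    alpha = alpha / alpha.sum()                                   # attention weights over premise words
    r = premise_Y.T @ alpha + np.tanh(W_t @ r_prev)               # (k,)
    return r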