def basic_lstm_model(self, inputs):
    print("Loading basic lstm model..")
    tstp = tf.shape(inputs)[1]  # number of time steps
    for i in range(self.config.rnn_numLayers):
        with tf.variable_scope('rnnLayer' + str(i)):
            lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.config.hidden_size)
            outputs, _ = tf.nn.dynamic_rnn(lstm_cell, inputs, self.ph_seqLen,  # (b_sz, tstp, h_sz)
                                           dtype=tf.float32, swap_memory=True,
                                           scope='basic_lstm_model_layer-' + str(i))
            inputs = outputs  # b_sz, tstp, h_sz
    mask = mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
    mask = tf.expand_dims(mask, dim=2)  # b_sz, tstp, 1

    # length-normalized average over the valid time steps
    aggregate_state = reduce_avg(outputs, mask, tf.expand_dims(self.ph_seqLen, 1), dim=-2)  # b_sz, h_sz
    inputs = aggregate_state
    inputs = tf.reshape(inputs, [-1, self.config.hidden_size])

    for i in range(self.config.fnn_numLayers):
        inputs = tf.nn.rnn_cell._linear(inputs, self.config.hidden_size, bias=True,
                                        scope='fnn_layer-' + str(i))
        inputs = tf.nn.tanh(inputs)
    aggregate_state = inputs
    logits = tf.nn.rnn_cell._linear(aggregate_state, self.config.class_num, bias=True,
                                    scope='fnn_softmax')
    return logits
def basic_cbow_model(self, inputs):
    tstp = tf.shape(inputs)[1]  # number of time steps
    mask = mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
    mask = tf.expand_dims(mask, dim=2)  # b_sz, tstp, 1

    # bag-of-words representation: length-normalized average of the embeddings
    aggregate_state = reduce_avg(inputs, mask, tf.expand_dims(self.ph_seqLen, 1), dim=-2)  # b_sz, emb_sz
    inputs = aggregate_state
    inputs = tf.reshape(inputs, [-1, self.config.embed_size])

    for i in range(self.config.fnn_numLayers):
        inputs = tf.nn.rnn_cell._linear(inputs, self.config.embed_size, bias=True,
                                        scope='fnn_layer-' + str(i))
        inputs = tf.nn.tanh(inputs)
    aggregate_state = inputs
    logits = tf.nn.rnn_cell._linear(aggregate_state, self.config.class_num, bias=True,
                                    scope='fnn_softmax')
    return logits
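# The averaging paths above call two helpers, mkMask and reduce_avg, that are
# defined elsewhere in this repository. The sketch below is only an assumption
# of their semantics (names prefixed with _sketch_ are hypothetical, not the
# repository's implementation): mkMask builds a boolean padding mask from the
# sequence lengths, and reduce_avg computes a length-normalized mean. Note the
# attention code further down calls a reduce_avg(x, lengths, dim=1) variant
# with a different signature; only the mask-based call is sketched here.
def _sketch_mkMask(seqLen, tstp):
    """Hypothetical helper: shape(b_sz,) lengths -> shape(b_sz, tstp) bool mask."""
    return tf.sequence_mask(seqLen, maxlen=tstp)  # b_sz, tstp

def _sketch_reduce_avg(inputs, mask, lengths, dim=-2):
    """Hypothetical helper: masked mean of `inputs` over axis `dim`.

    inputs:  shape(b_sz, tstp, h_sz)
    mask:    shape(b_sz, tstp, 1)
    lengths: shape(b_sz, 1), true sequence lengths
    """
    masked = inputs * tf.cast(mask, inputs.dtype)   # zero out padded time steps
    summed = tf.reduce_sum(masked, axis=dim)        # b_sz, h_sz
    return summed / tf.cast(lengths, inputs.dtype)  # length-normalized mean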
def flatten_attention(self, in_x, wNum, scope=None):
    '''
    :param in_x: shape(b_sz, wtstp, emb_sz)
    :param wNum: shape(b_sz,), word count per example
    :param scope: variable scope name
    :return: snt_rep, shape(b_sz, dim)
    '''
    b_sz, wtstp, _ = tf.unstack(tf.shape(in_x))
    emb_sz = int(in_x.get_shape()[-1])
    with tf.variable_scope(scope or 'encoding_attention'):

        with tf.variable_scope('snt_enc'):
            if self.config.seq_encoder == 'bigru':
                birnn_wd = self.biGRU(in_x, wNum, self.config.hidden_size, scope='biGRU')
            elif self.config.seq_encoder == 'bilstm':
                birnn_wd = self.biLSTM(in_x, wNum, self.config.hidden_size, scope='biLSTM')
            else:
                raise ValueError('no such encoder %s' % self.config.seq_encoder)

            '''shape(b_sz, dim)'''
            if self.config.attn_mode == 'avg':
                snt_rep = reduce_avg(birnn_wd, wNum, dim=1)
            elif self.config.attn_mode == 'attn':
                snt_rep = self.task_specific_attention(
                    birnn_wd, wNum, int(birnn_wd.get_shape()[-1]),
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            elif self.config.attn_mode == 'rout':
                snt_rep = self.routing_masked(
                    birnn_wd, wNum, int(birnn_wd.get_shape()[-1]),
                    self.config.out_caps_num, iter=self.config.rout_iter,
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            elif self.config.attn_mode == 'Rrout':
                snt_rep = self.reverse_routing_masked(
                    birnn_wd, wNum, int(birnn_wd.get_shape()[-1]),
                    self.config.out_caps_num, iter=self.config.rout_iter,
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            else:
                raise ValueError('no such attn mode %s' % self.config.attn_mode)
    return snt_rep
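# flatten_attention and hierachical_attention both delegate to self.biGRU /
# self.biLSTM, which are defined elsewhere in this class. A minimal sketch of
# what the biGRU variant presumably does, assuming a TF 1.x-style API and that
# forward and backward outputs are concatenated (the _sketch_ name and exact
# signature are assumptions, not the repository's code):
def _sketch_biGRU(in_x, seq_len, hidden_size, scope=None):
    """Hypothetical: shape(b_sz, tstp, emb_sz) -> shape(b_sz, tstp, 2*hidden_size)."""
    with tf.variable_scope(scope or 'biGRU'):
        cell_fw = tf.nn.rnn_cell.GRUCell(hidden_size)
        cell_bw = tf.nn.rnn_cell.GRUCell(hidden_size)
        (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, in_x, sequence_length=seq_len,
            dtype=tf.float32, swap_memory=True)
        return tf.concat([out_fw, out_bw], axis=-1)  # b_sz, tstp, 2*h_sz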
def hierachical_attention(self, in_x, sNum, wNum, scope=None):
    '''
    :param in_x: shape(b_sz, ststp, wtstp, emb_sz)
    :param sNum: shape(b_sz, ), sentence count per document
    :param wNum: shape(b_sz, ststp), word count per sentence
    :param scope: variable scope name
    :return: doc_rep, shape(b_sz, dim)
    '''
    b_sz, ststp, wtstp, _ = tf.unstack(tf.shape(in_x))
    emb_sz = int(in_x.get_shape()[-1])
    with tf.variable_scope(scope or 'hierachical_attention'):
        # collapse the sentence dimension so every sentence is encoded independently
        flatten_in_x = tf.reshape(in_x, [b_sz * ststp, wtstp, emb_sz])
        flatten_wNum = tf.reshape(wNum, [b_sz * ststp])

        with tf.variable_scope('sentence_enc'):
            if self.config.seq_encoder == 'bigru':
                flatten_birnn_x = self.biGRU(flatten_in_x, flatten_wNum,
                                             self.config.hidden_size, scope='biGRU')
            elif self.config.seq_encoder == 'bilstm':
                flatten_birnn_x = self.biLSTM(flatten_in_x, flatten_wNum,
                                              self.config.hidden_size, scope='biLSTM')
            else:
                raise ValueError('no such encoder %s' % self.config.seq_encoder)

            '''shape(b_sz*sNum, dim)'''
            if self.config.attn_mode == 'avg':
                flatten_attn_ctx = reduce_avg(flatten_birnn_x, flatten_wNum, dim=1)
            elif self.config.attn_mode == 'attn':
                flatten_attn_ctx = self.task_specific_attention(
                    flatten_birnn_x, flatten_wNum, int(flatten_birnn_x.get_shape()[-1]),
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            elif self.config.attn_mode == 'rout':
                flatten_attn_ctx = self.routing_masked(
                    flatten_birnn_x, flatten_wNum, int(flatten_birnn_x.get_shape()[-1]),
                    self.config.out_caps_num, iter=self.config.rout_iter,
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='rout')
            elif self.config.attn_mode == 'Rrout':
                flatten_attn_ctx = self.reverse_routing_masked(
                    flatten_birnn_x, flatten_wNum, int(flatten_birnn_x.get_shape()[-1]),
                    self.config.out_caps_num, iter=self.config.rout_iter,
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='Rrout')
            else:
                raise ValueError('no such attn mode %s' % self.config.attn_mode)

        # restore the sentence dimension: one vector per sentence
        snt_dim = int(flatten_attn_ctx.get_shape()[-1])
        snt_reps = tf.reshape(flatten_attn_ctx, shape=[b_sz, ststp, snt_dim])

        with tf.variable_scope('doc_enc'):
            if self.config.seq_encoder == 'bigru':
                birnn_snt = self.biGRU(snt_reps, sNum, self.config.hidden_size, scope='biGRU')
            elif self.config.seq_encoder == 'bilstm':
                birnn_snt = self.biLSTM(snt_reps, sNum, self.config.hidden_size, scope='biLSTM')
            else:
                raise ValueError('no such encoder %s' % self.config.seq_encoder)

            '''shape(b_sz, dim)'''
            if self.config.attn_mode == 'avg':
                doc_rep = reduce_avg(birnn_snt, sNum, dim=1)
            elif self.config.attn_mode == 'max':
                doc_rep = tf.reduce_max(birnn_snt, axis=1)
            elif self.config.attn_mode == 'attn':
                doc_rep = self.task_specific_attention(
                    birnn_snt, sNum, int(birnn_snt.get_shape()[-1]),
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            elif self.config.attn_mode == 'rout':
                doc_rep = self.routing_masked(
                    birnn_snt, sNum, int(birnn_snt.get_shape()[-1]),
                    self.config.out_caps_num, iter=self.config.rout_iter,
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            elif self.config.attn_mode == 'Rrout':
                doc_rep = self.reverse_routing_masked(
                    birnn_snt, sNum, int(birnn_snt.get_shape()[-1]),
                    self.config.out_caps_num, iter=self.config.rout_iter,
                    dropout=self.config.dropout,
                    is_train=self.ph_train, scope='attention')
            else:
                raise ValueError('no such attn mode %s' % self.config.attn_mode)
    return doc_rep
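# flatten_attention and hierachical_attention return a fixed-size sentence or
# document vector; a classification layer is expected on top of it. The sketch
# below shows a plain projection to class logits under that assumption; the
# _sketch_ name, 'fnn_softmax' scope, and variable names are hypothetical.
def _sketch_classifier_head(doc_rep, class_num, scope='fnn_softmax'):
    """Hypothetical: project shape(b_sz, dim) representations to shape(b_sz, class_num) logits."""
    with tf.variable_scope(scope):
        dim = int(doc_rep.get_shape()[-1])
        w = tf.get_variable('W', shape=[dim, class_num], dtype=tf.float32)
        b = tf.get_variable('b', shape=[class_num], dtype=tf.float32,
                            initializer=tf.zeros_initializer())
        return tf.matmul(doc_rep, w) + b  # b_sz, class_num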
def feed_back_lstm(self, inputs):

    def feed_back_net(inputs, seq_len, feed_back_steps):
        '''
        Args:
            inputs: shape(b_sz, tstp, emb_sz)
        '''
        shape_of_input = tf.shape(inputs)
        b_sz = shape_of_input[0]
        h_sz = self.config.hidden_size
        tstp = shape_of_input[1]
        emb_sz = self.config.embed_size

        def body(time, prev_output, state_ta):
            '''
            Args:
                prev_output: previous output, shape(b_sz, tstp, hidden_size)
            '''
            prev_output = tf.reshape(prev_output, shape=[-1, h_sz])  # shape(b_sz*tstp, h_sz)
            output_linear = tf.nn.rnn_cell._linear(prev_output, output_size=h_sz,  # shape(b_sz*tstp, h_sz)
                                                   bias=False, scope='output_transformer')
            output_linear = tf.reshape(output_linear, shape=[b_sz, tstp, h_sz])  # shape(b_sz, tstp, h_sz)
            output_linear = tf.tanh(output_linear)  # shape(b_sz, tstp, h_sz)

            # feed the transformed previous outputs back in alongside the raw input
            rnn_input = tf.concat(2, [output_linear, inputs],
                                  name='concat_output_input')  # shape(b_sz, tstp, h_sz+emb_sz)
            cell = tf.nn.rnn_cell.BasicLSTMCell(h_sz)
            cur_outputs, state = tf.nn.dynamic_rnn(cell, rnn_input, seq_len, dtype=tf.float32,
                                                   swap_memory=True, time_major=False,
                                                   scope='encoder')
            state = tf.concat(1, state)
            state_ta = state_ta.write(time, state)
            return time + 1, cur_outputs, state_ta  # shape(b_sz, tstp, h_sz)

        def condition(time, *_):
            return time < feed_back_steps

        state_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True,
                                  clear_after_read=True, size=0)
        initial_output = tf.zeros(shape=[b_sz, tstp, h_sz], dtype=inputs.dtype,
                                  name='initial_output')
        time = tf.constant(0, dtype=tf.int32)
        _, outputs, state_ta = tf.while_loop(condition, body,
                                             [time, initial_output, state_ta],
                                             swap_memory=True)
        final_state = state_ta.read(state_ta.size() - 1)
        return final_state, outputs

    _, outputs = feed_back_net(inputs, self.ph_seqLen, feed_back_steps=10)

    tstp = tf.shape(inputs)[1]  # number of time steps
    mask = mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
    mask = tf.expand_dims(mask, dim=2)  # b_sz, tstp, 1
    aggregate_state = reduce_avg(outputs, mask, tf.expand_dims(self.ph_seqLen, 1), dim=-2)  # b_sz, h_sz
    inputs = aggregate_state
    inputs = tf.reshape(inputs, [-1, self.config.hidden_size])

    for i in range(self.config.fnn_numLayers):
        inputs = tf.nn.rnn_cell._linear(inputs, self.config.hidden_size, bias=True,
                                        scope='fnn_layer-' + str(i))
        inputs = tf.nn.tanh(inputs)
    aggregate_state = inputs
    logits = tf.nn.rnn_cell._linear(aggregate_state, self.config.class_num, bias=True,
                                    scope='fnn_softmax')
    return logits
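# The builders above all stop at class logits. A hedged sketch of how those
# logits might be turned into a training objective; `labels`, `learning_rate`,
# the optimizer choice, and the _sketch_ name are placeholders here, not names
# taken from this repository.
def _sketch_training_head(logits, labels, learning_rate=1e-3):
    """Hypothetical: cross-entropy loss and an Adam train op on top of logits."""
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))  # scalar loss over the batch
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    return loss, train_op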