        def basic_lstm_model(inputs):
            print("Loading basic lstm model..")
            for i in range(self.config.rnn_numLayers):
                with tf.variable_scope('rnnLayer' + str(i)):
                    lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
                    outputs, _ = tf.nn.dynamic_rnn(
                        lstm_cell,
                        inputs,
                        self.ph_seqLen,  # sequence lengths, shape (b_sz,)
                        dtype=tf.float32,
                        swap_memory=True,
                        scope='basic_lstm_model_layer-' + str(i))
                    inputs = outputs  #b_sz, tstp, h_sz
            mask = mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
            mask = tf.expand_dims(mask, dim=2)  #b_sz, tstp, 1

            aggregate_state = reduce_avg(outputs,
                                         mask,
                                         tf.expand_dims(self.ph_seqLen, 1),
                                         dim=-2)  #b_sz, h_sz
            inputs = aggregate_state
            inputs = tf.reshape(inputs, [-1, self.config.hidden_size])

            for i in range(self.config.fnn_numLayers):
                inputs = rnn.rnn_cell._linear(inputs,
                                              self.config.hidden_size,
                                              bias=True,
                                              scope='fnn_layer-' + str(i))
                inputs = tf.nn.tanh(inputs)
            aggregate_state = inputs
            logits = rnn.rnn_cell._linear(aggregate_state,
                                          self.config.class_num,
                                          bias=True,
                                          scope='fnn_softmax')
            return logits
        def basic_cbow_model(inputs):
            mask = mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
            mask = tf.expand_dims(mask, dim=2)  #b_sz, tstp, 1

            aggregate_state = reduce_avg(inputs,
                                         mask,
                                         tf.expand_dims(self.ph_seqLen, 1),
                                         dim=-2)  #b_sz, emb_sz
            inputs = aggregate_state
            inputs = tf.reshape(inputs, [-1, self.config.embed_size])

            for i in range(self.config.fnn_numLayers):
                inputs = rnn.rnn_cell._linear(inputs,
                                              self.config.embed_size,
                                              bias=True,
                                              scope='fnn_layer-' + str(i))
                inputs = tf.nn.tanh(inputs)
            aggregate_state = inputs
            logits = rnn.rnn_cell._linear(aggregate_state,
                                          self.config.class_num,
                                          bias=True,
                                          scope='fnn_softmax')
            return logits
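Both snippets above rely on two helpers, mkMask and reduce_avg, that are defined elsewhere in the repository and not shown here. The following is a minimal sketch of what they plausibly do, reconstructed from the shape comments and the four-argument reduce_avg call above; the actual implementations may differ:

import tensorflow as tf

def mkMask(seq_len, max_len):
    # Hypothetical: boolean mask of shape (b_sz, max_len), True where position < seq_len.
    return tf.sequence_mask(seq_len, maxlen=max_len)

def reduce_avg(tensor, mask, lengths, dim=-2):
    # Hypothetical: masked mean over `dim`; mask broadcasts against tensor,
    # lengths has shape (b_sz, 1) so the division broadcasts over the feature axis.
    masked = tensor * tf.cast(mask, tensor.dtype)
    return tf.reduce_sum(masked, axis=dim) / tf.cast(lengths, tensor.dtype)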
Example #3
    def flatten_attention(self, in_x, wNum, scope=None):
        '''

        :param in_x: shape(b_sz, wtstp, emb_sz)
        :param wNum: shape(b_sz, )
        :param scope: variable scope name
        :return: sentence representation, shape(b_sz, dim)
        '''
        b_sz, wtstp, _ = tf.unstack(tf.shape(in_x))
        emb_sz = int(in_x.get_shape()[-1])
        with tf.variable_scope(scope or 'encoding_attention'):

            with tf.variable_scope('snt_enc'):
                if self.config.seq_encoder == 'bigru':
                    birnn_wd = self.biGRU(in_x,
                                          wNum,
                                          self.config.hidden_size,
                                          scope='biGRU')
                elif self.config.seq_encoder == 'bilstm':
                    birnn_wd = self.biLSTM(in_x,
                                           wNum,
                                           self.config.hidden_size,
                                           scope='biLSTM')
                else:
                    raise ValueError('no such encoder %s' %
                                     self.config.seq_encoder)
                '''shape(b_sz, dim)'''
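                # attn_mode selects how the per-word encoder states are pooled into a
                # single sentence vector: masked average, learned attention, or
                # (reverse) routing-based aggregation.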
                if self.config.attn_mode == 'avg':
                    snt_rep = reduce_avg(birnn_wd, wNum, dim=1)
                elif self.config.attn_mode == 'attn':
                    snt_rep = self.task_specific_attention(
                        birnn_wd,
                        wNum,
                        int(birnn_wd.get_shape()[-1]),
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                elif self.config.attn_mode == 'rout':
                    snt_rep = self.routing_masked(
                        birnn_wd,
                        wNum,
                        int(birnn_wd.get_shape()[-1]),
                        self.config.out_caps_num,
                        iter=self.config.rout_iter,
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                elif self.config.attn_mode == 'Rrout':
                    snt_rep = self.reverse_routing_masked(
                        birnn_wd,
                        wNum,
                        int(birnn_wd.get_shape()[-1]),
                        self.config.out_caps_num,
                        iter=self.config.rout_iter,
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                else:
                    raise ValueError('no such attn mode %s' %
                                     self.config.attn_mode)
        return snt_rep
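    # Usage sketch (hypothetical tensor names, not from the original file):
    # given word embeddings in_x of shape (b_sz, wtstp, emb_sz) and word counts
    # wNum of shape (b_sz,),
    #     snt_rep = self.flatten_attention(in_x, wNum, scope='flat_attn')
    # yields one representation per example, shape (b_sz, dim).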
    def hierachical_attention(self, in_x, sNum, wNum, scope=None):
        '''

        :param in_x: shape(b_sz, ststp, wtstp, emb_sz)
        :param sNum: shape(b_sz, )
        :param wNum: shape(b_sz, ststp)
        :param scope: variable scope name
        :return: document representation, shape(b_sz, dim)
        '''
        b_sz, ststp, wtstp, _ = tf.unstack(tf.shape(in_x))
        emb_sz = int(in_x.get_shape()[-1])
        with tf.variable_scope(scope or 'hierachical_attention'):
            flatten_in_x = tf.reshape(in_x, [b_sz * ststp, wtstp, emb_sz])
            flatten_wNum = tf.reshape(wNum, [b_sz * ststp])

            with tf.variable_scope('sentence_enc'):
                if self.config.seq_encoder == 'bigru':
                    flatten_birnn_x = self.biGRU(flatten_in_x,
                                                 flatten_wNum,
                                                 self.config.hidden_size,
                                                 scope='biGRU')
                elif self.config.seq_encoder == 'bilstm':
                    flatten_birnn_x = self.biLSTM(flatten_in_x,
                                                  flatten_wNum,
                                                  self.config.hidden_size,
                                                  scope='biLSTM')
                else:
                    raise ValueError('no such encoder %s' %
                                     self.config.seq_encoder)
                '''shape(b_sz*ststp, dim)'''
                if self.config.attn_mode == 'avg':
                    flatten_attn_ctx = reduce_avg(flatten_birnn_x,
                                                  flatten_wNum,
                                                  dim=1)
                elif self.config.attn_mode == 'attn':
                    flatten_attn_ctx = self.task_specific_attention(
                        flatten_birnn_x,
                        flatten_wNum,
                        int(flatten_birnn_x.get_shape()[-1]),
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                elif self.config.attn_mode == 'rout':
                    flatten_attn_ctx = self.routing_masked(
                        flatten_birnn_x,
                        flatten_wNum,
                        int(flatten_birnn_x.get_shape()[-1]),
                        self.config.out_caps_num,
                        iter=self.config.rout_iter,
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='rout')
                elif self.config.attn_mode == 'Rrout':
                    flatten_attn_ctx = self.reverse_routing_masked(
                        flatten_birnn_x,
                        flatten_wNum,
                        int(flatten_birnn_x.get_shape()[-1]),
                        self.config.out_caps_num,
                        iter=self.config.rout_iter,
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='Rrout')
                else:
                    raise ValueError('no such attn mode %s' %
                                     self.config.attn_mode)
            snt_dim = int(flatten_attn_ctx.get_shape()[-1])
            snt_reps = tf.reshape(flatten_attn_ctx,
                                  shape=[b_sz, ststp, snt_dim])

            with tf.variable_scope('doc_enc'):
                if self.config.seq_encoder == 'bigru':
                    birnn_snt = self.biGRU(snt_reps,
                                           sNum,
                                           self.config.hidden_size,
                                           scope='biGRU')
                elif self.config.seq_encoder == 'bilstm':
                    birnn_snt = self.biLSTM(snt_reps,
                                            sNum,
                                            self.config.hidden_size,
                                            scope='biLSTM')
                else:
                    raise ValueError('no such encoder %s' %
                                     self.config.seq_encoder)
                '''shape(b_sz, dim)'''
                if self.config.attn_mode == 'avg':
                    doc_rep = reduce_avg(birnn_snt, sNum, dim=1)
                elif self.config.attn_mode == 'max':
                    doc_rep = tf.reduce_max(birnn_snt, axis=1)
                elif self.config.attn_mode == 'attn':
                    doc_rep = self.task_specific_attention(
                        birnn_snt,
                        sNum,
                        int(birnn_snt.get_shape()[-1]),
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                elif self.config.attn_mode == 'rout':
                    doc_rep = self.routing_masked(
                        birnn_snt,
                        sNum,
                        int(birnn_snt.get_shape()[-1]),
                        self.config.out_caps_num,
                        iter=self.config.rout_iter,
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                elif self.config.attn_mode == 'Rrout':
                    doc_rep = self.reverse_routing_masked(
                        birnn_snt,
                        sNum,
                        int(birnn_snt.get_shape()[-1]),
                        self.config.out_caps_num,
                        iter=self.config.rout_iter,
                        dropout=self.config.dropout,
                        is_train=self.ph_train,
                        scope='attention')
                else:
                    raise ValueError('no such attn mode %s' %
                                     self.config.attn_mode)
        return doc_rep
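        # Usage sketch (hypothetical tensor names, not from the original file):
        # for documents embedded as in_x of shape (b_sz, ststp, wtstp, emb_sz),
        # with sentence counts sNum (b_sz,) and per-sentence word counts wNum
        # of shape (b_sz, ststp),
        #     doc_rep = self.hierachical_attention(in_x, sNum, wNum)
        # builds word -> sentence -> document representations, shape (b_sz, dim).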
        def feed_back_lstm(inputs):
            def feed_back_net(inputs, seq_len, feed_back_steps):
                '''
                Args:
                    inputs: shape(b_sz, tstp, emb_sz)
                '''
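                # The encoder is rerun feed_back_steps times over the full sequence:
                # each pass linearly transforms the previous pass's outputs, applies
                # tanh, concatenates them with the original embeddings, and feeds the
                # result to the LSTM again; the final state of every pass is stored.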
                shape_of_input = tf.shape(inputs)
                b_sz = shape_of_input[0]
                h_sz = self.config.hidden_size
                tstp = shape_of_input[1]
                emb_sz = self.config.embed_size

                def body(time, prev_output, state_ta):
                    '''
                    Args:
                        prev_output: previous output shape(b_sz, tstp, hidden_size)
                    '''

                    prev_output = tf.reshape(prev_output,
                                             shape=[-1, h_sz])  # shape(b_sz*tstp, h_sz)
                    output_linear = tf.nn.rnn_cell._linear(
                        prev_output,
                        output_size=h_sz,
                        bias=False,
                        scope='output_transformer')  # shape(b_sz*tstp, h_sz)
                    output_linear = tf.reshape(output_linear,
                                               shape=[b_sz, tstp, h_sz])  # shape(b_sz, tstp, h_sz)
                    output_linear = tf.tanh(output_linear)  # shape(b_sz, tstp, h_sz)

                    rnn_input = tf.concat(2, [output_linear, inputs],
                                          name='concat_output_input')  # shape(b_sz, tstp, h_sz+emb_sz)

                    cell = tf.nn.rnn_cell.BasicLSTMCell(h_sz)
                    cur_outputs, state = tf.nn.dynamic_rnn(cell,
                                                           rnn_input,
                                                           seq_len,
                                                           dtype=tf.float32,
                                                           swap_memory=True,
                                                           time_major=False,
                                                           scope='encoder')
                    state = tf.concat(1, state)
                    state_ta = state_ta.write(time, state)
                    return time + 1, cur_outputs, state_ta  #shape(b_sz, tstp, h_sz)

                def condition(time, *_):
                    return time < feed_back_steps

                state_ta = tf.TensorArray(dtype=tf.float32,
                                          dynamic_size=True,
                                          clear_after_read=True,
                                          size=0)
                initial_output = tf.zeros(shape=[b_sz, tstp, h_sz],
                                          dtype=inputs.dtype,
                                          name='initial_output')
                time = tf.constant(0, dtype=tf.int32)
                _, outputs, state_ta = tf.while_loop(
                    condition,
                    body, [time, initial_output, state_ta],
                    swap_memory=True)
                final_state = state_ta.read(state_ta.size() - 1)
                return final_state, outputs

            _, outputs = feed_back_net(inputs,
                                       self.ph_seqLen,
                                       feed_back_steps=10)

            mask = mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
            mask = tf.expand_dims(mask, dim=2)  #b_sz, tstp, 1

            aggregate_state = reduce_avg(outputs,
                                         mask,
                                         tf.expand_dims(self.ph_seqLen, 1),
                                         dim=-2)  #b_sz, h_sz
            inputs = aggregate_state
            inputs = tf.reshape(inputs, [-1, self.config.hidden_size])

            for i in range(self.config.fnn_numLayers):
                inputs = rnn.rnn_cell._linear(inputs,
                                              self.config.hidden_size,
                                              bias=True,
                                              scope='fnn_layer-' + str(i))
                inputs = tf.nn.tanh(inputs)
            aggregate_state = inputs
            logits = rnn.rnn_cell._linear(aggregate_state,
                                          self.config.class_num,
                                          bias=True,
                                          scope='fnn_softmax')
            return logits
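Each builder above returns unnormalized logits of shape (b_sz, class_num). As a usage sketch only (the placeholder and variable names below are assumptions, not taken from the original file), such logits would typically feed a softmax cross-entropy loss in this TF-1.x style:

import tensorflow as tf

labels = tf.placeholder(tf.int32, shape=[None], name='ph_labels')  # hypothetical label placeholder
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
predictions = tf.argmax(logits, axis=1)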