def add_model(self):
        """
            input_tensor #(batch_size, num_sentence, embed_size)
            input_len    #(batch_size)
        """

        b_sz = tf.shape(self.encoder_input)[0]
        tstp_enc = tf.shape(self.encoder_input)[1]
        tstp_dec = tf.shape(self.ph_decoder_label)[1]

        enc_in = self.encoder_input  # shape(b_sz, tstp_enc, s_emb_sz)
        enc_len = self.ph_input_encoder_len  # shape(b_sz,)
        dec_len = self.ph_input_encoder_len  # shape(b_sz,); decoder lengths reuse the encoder lengths, since tstp_dec == tstp_enc
        order_idx = self.ph_decoder_label  # shape(b_sz, tstp_dec)

        cell_dec = rnn_cell.BasicLSTMCell(self.config.h_dec_sz)
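        # The same decoder cell drives both the teacher-forced training graph and the
        # greedy test-time graph below; vscope.reuse_variables() ties their weights.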
        with tf.variable_scope('add_model') as vscope:
            out_logits = self.train_module(  # shape(b_sz, tstp_dec, tstp_enc)
                cell_dec,
                enc_in,
                enc_len,
                dec_len,
                order_idx,
                scope='decoder_train')
            vscope.reuse_variables()
            predict_idx = self.decoder_test(
                cell_dec, enc_in, enc_len, dec_len,
                scope='decoder_train')  # shape(b_sz, tstp_dec)

        train_loss, valid_loss = self.add_loss_op(out_logits, order_idx,
                                                  dec_len)

        return train_loss, valid_loss, predict_idx
Example #2
        def basic_lstm_model(inputs):
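            # Stacked-LSTM classifier: several dynamic_rnn layers, mean-pooled outputs,
            # a small feed-forward stack with tanh, then a projection to class logits.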
            print "Loading basic lstm model.."
            for i in range(self.config.rnn_numLayers):
                with tf.variable_scope('rnnLayer' + str(i)):
                    lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
                    outputs, _ = tf.nn.dynamic_rnn(
                        lstm_cell,
                        inputs,
                        self.ph_seqLen,  # sequence_length, shape(b_sz,)
                        dtype=tf.float32,
                        swap_memory=True,
                        scope='basic_lstm_model_layer-' + str(i))
                    inputs = outputs  #b_sz, tstp, h_sz
            # tstp (max time steps) comes from the enclosing scope of this snippet
            mask = TfUtils.mkMask(self.ph_seqLen, tstp)  # b_sz, tstp
            mask = tf.expand_dims(mask, axis=2)  # b_sz, tstp, 1

            aggregate_state = TfUtils.reduce_avg(outputs,
                                                 self.ph_seqLen,
                                                 dim=1)  #b_sz, h_sz
            inputs = aggregate_state
            inputs = tf.reshape(inputs, [-1, self.config.hidden_size])

            for i in range(self.config.fnn_numLayers):
                inputs = TfUtils.linear(inputs,
                                        self.config.hidden_size,
                                        bias=True,
                                        scope='fnn_layer-' + str(i))
                inputs = tf.nn.tanh(inputs)
            aggregate_state = inputs
            logits = TfUtils.linear(aggregate_state,
                                    self.config.class_num,
                                    bias=True,
                                    scope='fnn_softmax')
            return logits
        def lstm_sentence_rep(input):
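            # Sentence representation: a bidirectional LSTM over each sentence's words,
            # averaged over valid positions, reshaped back to (b_sz, tstps_en, 2*h_sz).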
            with tf.variable_scope('lstm_sentence_rep_scope') as scope:
                input = tf.reshape(input,
                                   shape=[b_sz * tstps_en, -1, emb_sz
                                          ])  #(b_sz*tstps_en, len_sen, emb_sz)
                length = tf.reshape(self.ph_input_encoder_sentence_len,
                                    shape=[-1])  #(b_sz*tstps_en)

                lstm_cell = rnn_cell.BasicLSTMCell(h_sz)
                """tup(shape(b_sz*tstp_enc, len_sen, h_sz))"""
                rep_out, _ = tf.nn.bidirectional_dynamic_rnn(  # tup(shape(b_sz*tstp_enc, len_sen, h_sz))
                    lstm_cell,
                    lstm_cell,
                    input,
                    length,
                    dtype=tf.float32,
                    swap_memory=True,
                    time_major=False,
                    scope='sentence_encode')

                rep_out = tf.concat(
                    axis=2, values=rep_out)  #(b_sz*tstps_en, len_sen, h_sz*2)
                rep_out = TfUtils.reduce_avg(
                    rep_out, length, dim=1)  # shape(b_sz*tstps_en, h_sz*2)
                output = tf.reshape(rep_out,
                                    shape=[b_sz, tstps_en, 2 * h_sz
                                           ])  #(b_sz, tstps_en, h_sz*2)

            return output, None, None
    def train_module(self,
                     cell_dec,
                     encoder_inputs,
                     enc_lengths,
                     dec_lengths,
                     order_index,
                     scope=None):
        '''
        Args:
            cell_dec : lstm cell object, a configuration
            encoder_inputs : shape(b_sz, tstp_enc, s_emb_sz)
            enc_lengths : shape(b_sz,), encoder input lengths
            dec_lengths : shape(b_sz,), decoder input lengths
            order_index : shape(b_sz, tstp_dec), decoder label

        '''
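        # Pointer-network style training: the decoder runs with teacher forcing over
        # the gold order_index, and every decoder step scores every encoder position.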
        small_num = -np.Inf
        input_shape = tf.shape(encoder_inputs)
        b_sz = input_shape[0]
        tstp_enc = input_shape[1]
        tstp_dec = tstp_enc  # with no noising, the decoder runs the same number of steps as the encoder
        h_enc_sz = self.config.h_enc_sz
        h_dec_sz = self.config.h_dec_sz
        s_emb_sz = np.int(encoder_inputs.get_shape()
                          [2])  # should be a python-determined number

        cell_enc = rnn_cell.BasicLSTMCell(self.config.h_enc_sz)

        def enc(dec_h, in_x, lengths, fake_call=False):
            '''
            Args:
                dec_h: shape(b_sz, tstp_dec, h_dec_sz)
                in_x: shape(b_sz, tstp_enc, s_emb_sz)
                lengths: shape(b_sz)
            Returns:
                res: shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
            '''
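            # Re-encode the encoder inputs conditioned on each decoder hidden state:
            # inputs are tiled to a (b_sz*tstp_dec, ...) batch, run through a BiLSTM
            # and self-attention, then concatenated back with the raw inputs.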
            def func_f(in_x, enc_h, in_h_hat, fake_call=False):
                '''
                Args:
                    in_x: shape(b_sz, tstp_enc, s_emb_sz)
                    enc_h: shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                    in_h_hat: shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                Returns:
                    res: shape(b_sz, tstp_dec, tstp_enc, s_emb_sz+h_enc_sz*4)
                '''
                if fake_call:
                    return s_emb_sz + h_enc_sz * 4

                in_x_sz = int(in_x.get_shape()[-1])
                in_h_sz = int(enc_h.get_shape()[-1])
                if not in_x_sz:
                    raise ValueError('last dimension of the first' +
                                     ' arg should be known, while got %s' %
                                     (str(type(in_x_sz))))
                if not in_h_sz:
                    raise ValueError('last dimension of the second' +
                                     ' arg should be known, while got %s' %
                                     (str(type(in_h_sz))))
                enc_in_ex = tf.expand_dims(
                    in_x, 1)  # shape(b_sz, 1, tstp_enc, s_emb_sz)
                enc_in = tf.tile(
                    enc_in_ex,  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz)
                    [1, tstp_dec, 1, 1])
                res = tf.concat(axis=3, values=[enc_in, enc_h, in_h_hat])
                return res  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz+h_enc_sz*4)

            def attend(enc_h, enc_len):
                '''
                Args:
                    enc_h: shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                    enc_len: shape(b_sz)
                '''
                enc_len = tf.expand_dims(enc_len, 1)  # shape(b_sz, 1)
                attn_enc_len = tf.tile(enc_len, [1, tstp_dec])
                attn_enc_len = tf.reshape(attn_enc_len, [b_sz * tstp_dec])
                attn_enc_h = tf.reshape(
                    enc_h,  # shape(b_sz*tstp_dec, tstp_enc, h_enc_sz*2)
                    [b_sz * tstp_dec, tstp_enc,
                     np.int(enc_h.get_shape()[-1])])
                attn_out = TfUtils.self_attn(  # shape(b_sz*tstp_dec, tstp_enc, h_enc_sz*2)
                    attn_enc_h, attn_enc_len)
                h_hat = tf.reshape(
                    attn_out,  # shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                    [
                        b_sz, tstp_dec, tstp_enc,
                        np.int(attn_out.get_shape()[-1])
                    ])
                return h_hat

            if fake_call:
                return func_f(None, None, None, fake_call=True)

            def get_lstm_in_len():
                inputs = func_enc_input(
                    dec_h, in_x)  # shape(b_sz, tstp_dec, tstp_enc, enc_emb_sz)
                enc_emb_sz = np.int(inputs.get_shape()[-1])
                enc_in = tf.reshape(
                    inputs, shape=[b_sz * tstp_dec, tstp_enc, enc_emb_sz])
                enc_len = tf.expand_dims(lengths, 1)  # shape(b_sz, 1)
                enc_len = tf.tile(enc_len,
                                  [1, tstp_dec])  # shape(b_sz, tstp_dec)
                enc_len = tf.reshape(
                    enc_len, [b_sz * tstp_dec])  # shape(b_sz*tstp_dec,)
                return enc_in, enc_len

            '''shape(b_sz*tstp_dec, tstp_enc, enc_emb_sz), shape(b_sz*tstp_dec)'''
            enc_in, enc_len = get_lstm_in_len()
            '''tup(shape(b_sz*tstp_dec, tstp_enc, h_enc_sz))'''
            lstm_out, _ = tf.nn.bidirectional_dynamic_rnn(cell_enc,
                                                          cell_enc,
                                                          enc_in,
                                                          enc_len,
                                                          swap_memory=True,
                                                          dtype=tf.float32,
                                                          scope='sent_encoder')
            enc_out = tf.concat(
                axis=2,
                values=lstm_out)  # shape(b_sz*tstp_dec, tstp_enc, h_enc_sz*2)
            enc_out = tf.reshape(
                enc_out,  # shape(b_sz, tstp_dec, tstp_enc, h_enc_sz*2)
                shape=[b_sz, tstp_dec, tstp_enc, h_enc_sz * 2])

            enc_out_hat = attend(enc_out, lengths)
            res = func_f(in_x, enc_out, enc_out_hat)
            return res  # shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)

        def func_enc_input(dec_h, enc_input, fake_call=False):
            '''
            Args:
                enc_input: encoder input, shape(b_sz, tstp_enc, s_emb_sz)
                dec_h: decoder hidden state, shape(b_sz, tstp_dec, h_dec_sz)
            Returns:
                output: shape(b_sz, tstp_dec, tstp_enc, s_emb_sz+h_dec_sz)
            '''
            enc_emb_sz = s_emb_sz + h_dec_sz
            if fake_call:
                return enc_emb_sz

            dec_h_ex = tf.expand_dims(dec_h,
                                      2)  # shape(b_sz, tstp_dec, 1, h_dec_sz)
            dec_h_tile = tf.tile(
                dec_h_ex,  # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
                [1, 1, tstp_enc, 1])
            enc_in_ex = tf.expand_dims(enc_input,
                                       1)  # shape(b_sz, 1, tstp_enc, s_emb_sz)
            enc_in_tile = tf.tile(
                enc_in_ex,  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz)
                [1, tstp_dec, 1, 1])
            output = tf.concat(
                axis=3,  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz+h_dec_sz)
                values=[enc_in_tile, dec_h_tile])

            output = tf.reshape(
                output, shape=[b_sz, tstp_dec, tstp_enc, s_emb_sz + h_dec_sz])
            return output  # shape(b_sz, tstp_dec, tstp_enc, s_emb_sz+h_dec_sz)

        def func_point_logits(dec_h, enc_ptr, enc_len):
            '''
            Args:
                dec_h : shape(b_sz, tstp_dec, h_dec_sz)
                enc_ptr : shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
                enc_len : shape(b_sz,)
            '''
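            # Additive (Bahdanau-style) pointer scoring: v^T tanh(W [dec_h; enc_ptr]),
            # with positions beyond enc_len masked to -inf before any softmax.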
            dec_h_ex = tf.expand_dims(
                dec_h, axis=2)  # shape(b_sz, tstp_dec, 1, h_dec_sz)
            dec_h_ex = tf.tile(dec_h_ex,
                               [1, 1, tstp_enc, 1
                                ])  # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
            linear_concat = tf.concat(axis=3, values=[
                dec_h_ex, enc_ptr
            ])  # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz+ Ptr_sz)
            point_linear = TfUtils.last_dim_linear(  # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
                linear_concat,
                output_size=h_dec_sz,
                bias=False,
                scope='Ptr_W')
            point_v = TfUtils.last_dim_linear(  # shape(b_sz, tstp_dec, tstp_enc, 1)
                tf.tanh(point_linear),
                output_size=1,
                bias=False,
                scope='Ptr_V')

            point_logits = tf.squeeze(
                point_v, axis=[3])  # shape(b_sz, tstp_dec, tstp_enc)

            enc_len = tf.expand_dims(enc_len, 1)  # shape(b_sz, 1)
            enc_len = tf.tile(enc_len, [1, tstp_dec])  # shape(b_sz, tstp_dec)
            mask = TfUtils.mkMask(
                enc_len, maxLen=tstp_enc)  # shape(b_sz, tstp_dec, tstp_enc)
            point_logits = tf.where(
                mask,
                point_logits,  # shape(b_sz, tstp_dec, tstp_enc)
                tf.ones_like(point_logits) * small_num)

            return point_logits

        def get_initial_state(hidden_sz):
            '''
            Args:
                hidden_sz: must be a python determined number
            '''
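            # The decoder starts from a linear projection of the mean-pooled encoder
            # inputs (placed in the c slot of the LSTMStateTuple; h starts at zeros).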
            avg_in_x = TfUtils.reduce_avg(
                encoder_inputs,  # shape(b_sz, s_emb_sz)
                enc_lengths,
                dim=1)
            state = TfUtils.linear(
                avg_in_x,
                hidden_sz,  # shape(b_sz, hidden_sz)
                bias=False,
                scope='initial_transformation')
            state = rnn_cell.LSTMStateTuple(state, tf.zeros_like(state))
            return state

        def get_bos(emb_sz):
            with tf.variable_scope('bos_scope') as vscope:
                try:
                    ret = tf.get_variable(name='bos',
                                          shape=[1, emb_sz],
                                          dtype=tf.float32)
                except ValueError:  # variable already exists, reuse it
                    vscope.reuse_variables()
                    ret = tf.get_variable(name='bos',
                                          shape=[1, emb_sz],
                                          dtype=tf.float32)
            ret_bos = tf.tile(ret, [b_sz, 1])
            return ret_bos

        def decoder():
            def get_dec_in():
                dec_in = TfUtils.batch_embed_lookup(
                    encoder_inputs,
                    order_index)  # shape(b_sz, tstp_dec, s_emb_sz)
                bos = get_bos(s_emb_sz)  # shape(b_sz, s_emb_sz)
                bos = tf.expand_dims(bos, 1)  # shape(b_sz, 1, s_smb_sz)
                dec_in = tf.concat(
                    axis=1,
                    values=[bos, dec_in])  # shape(b_sz, tstp_dec+1, s_emb_sz)
                dec_in = dec_in[:, :-1, :]  # shape(b_sz, tstp_dec, s_emb_sz)
                return dec_in

            dec_in = get_dec_in()  # shape(b_sz, tstp_dec, s_emb_sz)
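            # Teacher forcing: dec_in holds the gold sentences (order_index) looked up
            # from the encoder inputs, shifted right one step behind a learned BOS.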
            initial_state = get_initial_state(
                h_dec_sz)  # LSTMStateTuple of shape(b_sz, h_dec_sz)
            dec_out, _ = tf.nn.dynamic_rnn(  # dec_out: shape(b_sz, tstp_dec, h_dec_sz)
                cell_dec,
                dec_in,  # shape(b_sz, tstp_dec, s_emb_sz)
                dec_lengths,
                initial_state=initial_state,
                swap_memory=True,
                dtype=tf.float32,
                scope=scope)
            with tf.variable_scope(scope):
                enc_out = enc(
                    dec_out,  # shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
                    encoder_inputs,
                    enc_lengths)
                point_logits = func_point_logits(
                    dec_out, enc_out,
                    enc_lengths)  # shape(b_sz, tstp_dec, tstp_enc)
            return point_logits

        point_logits = decoder()  # shape(b_sz, tstp_dec, tstp_enc)
        return point_logits
    def decoder_test(self,
                     cell_dec,
                     encoder_inputs,
                     enc_lengths,
                     dec_lengths,
                     scope=None):
        '''
        Args:
            cell_dec : lstm cell object, a configuration
            encoder_inputs : shape(b_sz, tstp_enc, s_emb_sz)
            enc_lengths : shape(b_sz,), encoder input lengths
            dec_lengths : shape(b_sz,), decoder input lengths
        '''

        small_num = -np.Inf
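        # Greedy test-time decoding: each step scores all encoder positions with the
        # pointer mechanism and the argmax over unused positions is fed back as the
        # next decoder input.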
        input_shape = tf.shape(encoder_inputs)
        b_sz = input_shape[0]
        tstp_enc = input_shape[1]
        tstp_dec = tstp_enc  # with no noising, the decoder runs the same number of steps as the encoder
        h_enc_sz = self.config.h_enc_sz
        h_dec_sz = self.config.h_dec_sz
        s_emb_sz = np.int(encoder_inputs.get_shape()
                          [2])  # should be a python-determined number

        # dec_emb_sz not determined
        cell_enc = rnn_cell.BasicLSTMCell(self.config.h_enc_sz)

        def enc(dec_h, in_x, lengths, fake_call=False):
            '''
            Args:
                dec_h: shape(b_sz, h_dec_sz), current decoder hidden state
                in_x: shape(b_sz, tstp_enc, s_emb_sz)
                lengths: shape(b_sz,)
            Returns:
                res: shape(b_sz, tstp_enc, dec_emb_sz)
            '''
            def func_f(in_x, in_h, in_h_hat, fake_call=False):
                if fake_call:
                    return s_emb_sz + h_enc_sz * 4

                in_x_sz = int(in_x.get_shape()[-1])
                in_h_sz = int(in_h.get_shape()[-1])
                if not in_x_sz:
                    raise ValueError('last dimension of the first' +
                                     ' arg should be known, while got %s' %
                                     (str(type(in_x_sz))))
                if not in_h_sz:
                    raise ValueError('last dimension of the second' +
                                     ' arg should be known, while got %s' %
                                     (str(type(in_h_sz))))
                res = tf.concat(axis=2, values=[in_x, in_h, in_h_hat])
                return res

            if fake_call:
                return func_f(None, None, None, fake_call=True)
            inputs = func_enc_input(dec_h, in_x)

            lstm_out, _ = tf.nn.bidirectional_dynamic_rnn(cell_enc,
                                                          cell_enc,
                                                          inputs,
                                                          lengths,
                                                          swap_memory=True,
                                                          dtype=tf.float32,
                                                          scope='sent_encoder')
            enc_out = tf.concat(
                axis=2, values=lstm_out)  # shape(b_sz, tstp_enc, h_enc_sz*2)
            enc_out = tf.reshape(enc_out, [b_sz, tstp_enc, h_enc_sz * 2])

            enc_out_hat = TfUtils.self_attn(enc_out, lengths)
            res = func_f(in_x, enc_out, enc_out_hat)
            return res  # shape(b_sz, tstp_enc, dec_emb_sz)

        def func_enc_input(dec_h, enc_input, fake_call=False):
            '''
            Args:
                enc_input: encoder input, shape(b_sz, tstp_enc, s_emb_sz)
                dec_h: decoder hidden state, shape(b_sz, h_dec_sz)
            '''
            enc_emb_sz = s_emb_sz + h_dec_sz
            if fake_call:
                return enc_emb_sz

            dec_h_ex = tf.expand_dims(dec_h, 1)  # shape(b_sz, 1, h_dec_sz)
            dec_h_tile = tf.tile(dec_h_ex, [1, tstp_enc, 1])

            output = tf.concat(axis=2, values=[
                enc_input, dec_h_tile
            ])  # shape(b_sz, tstp_enc, s_emb_sz + h_dec_sz)
            output = tf.reshape(output,
                                shape=[b_sz, tstp_enc, s_emb_sz + h_dec_sz])
            return output  # shape(b_sz, tstp_enc, s_emb_sz + h_dec_sz)

        enc_emb_sz = func_enc_input(None, None, fake_call=True)
        dec_emb_sz = enc(None, None, None, fake_call=True)

        def func_point_logits(dec_h, enc_e, enc_len):
            '''
            Args:
                dec_h : shape(b_sz, h_dec_sz)
                enc_e : shape(b_sz, tstp_enc, dec_emb_sz)
                enc_len : shape(b_sz,)
            '''

            dec_h_ex = tf.expand_dims(dec_h,
                                      axis=1)  # shape(b_sz, 1, h_dec_sz)
            dec_h_ex = tf.tile(
                dec_h_ex, [1, tstp_enc, 1])  # shape(b_sz, tstp_enc, h_dec_sz)
            linear_concat = tf.concat(axis=2, values=[
                dec_h_ex, enc_e
            ])  # shape(b_sz, tstp_enc, h_dec_sz+ dec_emb_sz)
            point_linear = TfUtils.last_dim_linear(  # shape(b_sz, tstp_enc, h_dec_sz)
                linear_concat,
                output_size=h_dec_sz,
                bias=False,
                scope='Ptr_W')
            point_v = TfUtils.last_dim_linear(  # shape(b_sz, tstp_enc, 1)
                tf.tanh(point_linear),
                output_size=1,
                bias=False,
                scope='Ptr_V')
            point_logits = tf.squeeze(point_v,
                                      axis=[2])  # shape(b_sz, tstp_enc)
            mask = TfUtils.mkMask(enc_len,
                                  maxLen=tstp_enc)  # shape(b_sz, tstp_enc)
            point_logits = tf.where(mask, point_logits,
                                    tf.ones_like(point_logits) *
                                    small_num)  # shape(b_sz, tstp_enc)

            return point_logits

        def func_point_idx(dec_h, enc_e, enc_len, hit_mask):
            '''
            Args:
                hit_mask: shape(b_sz, tstp_enc)
            '''
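            # Positions already emitted (hit_mask) get probability zero, so each
            # encoder sentence can be pointed to at most once during decoding.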
            logits = func_point_logits(dec_h, enc_e,
                                       enc_len)  # shape(b_sz, tstp_enc)
            prob = tf.nn.softmax(logits)
            prob = tf.where(hit_mask,
                            tf.zeros_like(prob),
                            prob,
                            name='mask_hit_pos')
            idx = tf.cast(tf.arg_max(prob, dimension=1),
                          dtype=tf.int32)  # shape(b_sz,) type of int32
            return idx  # shape(b_sz,)

        def get_bos(emb_sz):
            with tf.variable_scope('bos_scope') as vscope:
                try:
                    ret = tf.get_variable(name='bos',
                                          shape=[1, emb_sz],
                                          dtype=tf.float32)
                except ValueError:  # variable already exists, reuse it
                    vscope.reuse_variables()
                    ret = tf.get_variable(name='bos',
                                          shape=[1, emb_sz],
                                          dtype=tf.float32)
            ret_bos = tf.tile(ret, [b_sz, 1])
            return ret_bos

        def get_initial_state(hidden_sz):
            '''
            Args:
                hidden_sz: must be a python determined number
            '''
            avg_in_x = TfUtils.reduce_avg(
                encoder_inputs,  # shape(b_sz, s_emb_sz)
                enc_lengths,
                dim=1)
            state = TfUtils.linear(
                avg_in_x,
                hidden_sz,  # shape(b_sz, hidden_sz)
                bias=False,
                scope='initial_transformation')
            state = rnn_cell.LSTMStateTuple(state, tf.zeros_like(state))
            return state

        bos = get_bos(s_emb_sz)  # shape(b_sz, s_emb_sz)

        init_state = get_initial_state(h_dec_sz)

        def loop_fn(time, cell_output, cell_state, hit_mask):
            """
            Args:
                cell_output: shape(b_sz, h_dec_sz) ==> d
                cell_state: tup(shape(b_sz, h_dec_sz))
                pointer_logits_ta: pointer logits tensorArray
                hit_mask: shape(b_sz, tstp_enc)
            """

            if cell_output is None:  # time == 0
                next_cell_state = init_state
                next_input = bos  # shape(b_sz, dec_emb_sz)
                next_idx = tf.zeros(shape=[b_sz],
                                    dtype=tf.int32)  # shape(b_sz,)
                elements_finished = tf.zeros(shape=[b_sz],
                                             dtype=tf.bool,
                                             name='elem_finished')
                next_hit_mask = tf.zeros(shape=[b_sz, tstp_enc],
                                         dtype=tf.bool,
                                         name='hit_mask')
            else:

                next_cell_state = cell_state

                encoder_e = enc(
                    cell_output, encoder_inputs,
                    enc_lengths)  # shape(b_sz, tstp_enc, dec_emb_sz)
                next_idx = func_point_idx(cell_output, encoder_e, enc_lengths,
                                          hit_mask)  # shape(b_sz,)

                cur_hit_mask = tf.one_hot(
                    next_idx,
                    on_value=True,  # shape(b_sz, tstp_enc)
                    off_value=False,
                    depth=tstp_enc,
                    dtype=tf.bool)
                next_hit_mask = tf.logical_or(
                    hit_mask,
                    cur_hit_mask,  # shape(b_sz, tstp_enc)
                    name='next_hit_mask')

                next_input = TfUtils.batch_embed_lookup(
                    encoder_inputs, next_idx)  # shape(b_sz, s_emb_sz)

                elements_finished = (time >= dec_lengths)  # shape(b_sz,)

            return (elements_finished, next_input, next_cell_state,
                    next_hit_mask, next_idx)

        emit_idx_ta, _ = myRNN.train_rnn(cell_dec, loop_fn, scope=scope)
        output_idx = emit_idx_ta.stack()  # shape(tstp_dec, b_sz)
        output_idx = tf.transpose(output_idx,
                                  perm=[1, 0])  # shape(b_sz, tstp_dec)

        return output_idx  # shape(b_sz, tstp_dec)
Example #6
    def add_logits_op2(self):
        """利用BLSTM生成结果,batch中每个句子的每个单词都有一个结果,一个结果是n维的变量,n大小为类别的数目"""
        with tf.variable_scope('Premise_encoder'):
            lstm_cell = rnn_cell.BasicLSTMCell(hidden_size_lstm)
            lstm_cell = rnn_cell.DropoutWrapper(lstm_cell,
                                                input_keep_prob=self.dropout,
                                                output_keep_prob=self.dropout)
            Premise_out, Premise_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=lstm_cell,
                cell_bw=lstm_cell,
                inputs=self.seq1_word_embeddings,
                sequence_length=self.sequence1_lengths,
                dtype=tf.float32,
                swap_memory=True)
            Premise_output_fw, Premise_output_bw = Premise_out
            Premise_states_fw, Premise_states_bw = Premise_state
            Premise_out = tf.concat(Premise_out, 2)
            Premise_state = tf.concat(Premise_state, 2)
        with tf.variable_scope('Hypothesis_encoder'):
            lstm_cell = rnn_cell.BasicLSTMCell(hidden_size_lstm)
            lstm_cell = rnn_cell.DropoutWrapper(lstm_cell,
                                                input_keep_prob=self.dropout,
                                                output_keep_prob=self.dropout)
            Hypo_out, Hypo_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=lstm_cell,
                cell_bw=lstm_cell,
                inputs=self.seq2_word_embeddings,
                sequence_length=self.sequence2_lengths,
                # initial_state_fw=Premise_states_fw,
                # initial_state_bw=Premise_states_bw,
                dtype=tf.float32,
                swap_memory=True)
            print('before=', np.shape(Hypo_state[1]))
            Hypo_out = tf.concat(Hypo_out, 2)
            Hypo_state = tf.concat(Hypo_state, 2)

        def w2w_attn(Premise_out,
                     Hypo_out,
                     seqLen_Premise,
                     seqLen_Hypo,
                     scope=None):
            with tf.variable_scope(scope or 'Attn_layer'):
                attn_cell = AttnCell(196 * 2, Premise_out, seqLen_Premise)
                attn_cell = rnn_cell.DropoutWrapper(
                    attn_cell,
                    input_keep_prob=self.dropout,
                    output_keep_prob=self.dropout)

                _, r_state = tf.nn.dynamic_rnn(attn_cell,
                                               Hypo_out,
                                               seqLen_Hypo,
                                               dtype=Hypo_out.dtype,
                                               swap_memory=True)
            return r_state

        r_L = w2w_attn(Premise_out,
                       Hypo_out,
                       self.sequence1_lengths,
                       self.sequence2_lengths,
                       scope='w2w_attention')

        hypo_state1 = tf.reshape(Hypo_state[1], [-1, 392])
        hypo_state1 = tf.nn.dropout(hypo_state1, 0.5)
        print('***********', np.shape(r_L))
        print('***********', np.shape(hypo_state1))

        h_star = tf.tanh(
            linear(
                [r_L, hypo_state1],  # shape (b_sz, h_sz)
                392,
                bias=False,
                scope='linear_trans'))
        input_fully = h_star
        output = tf.nn.dropout(input_fully, self.dropout)
        W = tf.get_variable('W',
                            dtype=tf.float32,
                            shape=[hidden_size_lstm * 2, ntags])
        b = tf.get_variable('b',
                            dtype=tf.float32,
                            shape=[ntags],
                            initializer=tf.zeros_initializer())
        pred = tf.matmul(output, W) + b
        logits = tf.reshape(pred, [-1, ntags])
        logits = tf.nn.softmax(logits)
        self.logits = logits
        '''
        for i in range(2):
            with tf.variable_scope('fully_connect_'+str(i)):
                logits = tf.contrib.layers.fully_connected(
                    input_fully, 300 * 2, activation_fn=None)
                input_fully = tf.tanh(logits)
        with tf.name_scope('Softmax'):
            logits = tf.contrib.layers.fully_connected(
                input_fully, self.config.class_num, activation_fn=None)
        self.logits = logits
        '''
        '''

        output1 = attention((output_fw12, output_bw12),
                            attention_size, return_alphas=False)
        # output_fw21 = tf.concat([output_fw2, output_fw1], axis=1)
        # output_bw21 = tf.concat([output_bw2, output_bw1], axis=1)
        output_fw21 = output_fw2
        output_bw21 = output_bw2
        output2 = attention((output_fw21, output_bw21),
                            attention_size, return_alphas=False)
        # output = output1 + output2
        print('output1=', np.shape(output1))
        print('output2=', np.shape(output2))
        output = tf.concat([output1, output2], axis=1)
        output = tf.nn.dropout(output, self.dropout)  # dropout
        print('shape of output=', np.shape(output))
        # 接下来构造映射层
        # W = tf.get_variable('W', dtype=tf.float32,
        #                     shape=[hidden_size_lstm*2, ntags])
        W = tf.get_variable('W', dtype=tf.float32,
                            shape=[hidden_size_lstm * 4, ntags])
        b = tf.get_variable('b', dtype=tf.float32, shape=[ntags],
                            initializer=tf.zeros_initializer())
        pred = tf.matmul(output, W) + b
        logits = tf.reshape(pred, [-1, ntags])
        logits = tf.nn.softmax(logits)
        self.logits = logits
        '''
        pass
Example #7
    def add_model(self, input_x1, input_x2, seqLen_x1, seqLen_x2):
        '''
        dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
        '''
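        # Premise LSTM -> hypothesis LSTM (initialised with the premise final state)
        # -> word-to-word attention -> fully connected layers -> class logits.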
        with tf.variable_scope('Premise_encoder'):
            lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
            lstm_cell = rnn_cell.DropoutWrapper(
                lstm_cell,
                input_keep_prob=self.config.dropout,
                output_keep_prob=self.config.dropout)
            Premise_out, Premise_state = tf.nn.dynamic_rnn(
                cell=lstm_cell,
                inputs=input_x1,
                sequence_length=seqLen_x1,
                dtype=tf.float32,
                swap_memory=True)
        with tf.variable_scope('Hypothesis_encoder'):
            lstm_cell = rnn_cell.BasicLSTMCell(self.config.hidden_size)
            lstm_cell = rnn_cell.DropoutWrapper(
                lstm_cell,
                input_keep_prob=self.config.dropout,
                output_keep_prob=self.config.dropout)
            Hypo_out, Hypo_state = tf.nn.dynamic_rnn(
                cell=lstm_cell,
                inputs=input_x2,
                sequence_length=seqLen_x2,
                initial_state=Premise_state,
                swap_memory=True)

        def w2w_attn(Premise_out,
                     Hypo_out,
                     seqLen_Premise,
                     seqLen_Hypo,
                     scope=None):
            with tf.variable_scope(scope or 'Attn_layer'):
                attn_cell = AttnCell(self.config.hidden_size, Premise_out,
                                     seqLen_Premise)
                attn_cell = rnn_cell.DropoutWrapper(
                    attn_cell,
                    input_keep_prob=self.config.dropout,
                    output_keep_prob=self.config.dropout)

                _, r_state = tf.nn.dynamic_rnn(attn_cell,
                                               Hypo_out,
                                               seqLen_Hypo,
                                               dtype=Hypo_out.dtype,
                                               swap_memory=True)
            return r_state

        r_L = w2w_attn(Premise_out,
                       Hypo_out,
                       seqLen_x1,
                       seqLen_x2,
                       scope='w2w_attention')

        h_star = tf.tanh(
            linear(
                [r_L, Hypo_state[1]],  # shape (b_sz, h_sz)
                self.config.hidden_size,
                bias=False,
                scope='linear_trans'))
        input_fully = h_star
        for i in range(self.config.fnn_layers):
            with tf.variable_scope('fully_connect_' + str(i)):
                logits = tf.contrib.layers.fully_connected(
                    input_fully,
                    self.config.hidden_size * 2,
                    activation_fn=None)
                input_fully = tf.tanh(logits)
        with tf.name_scope('Softmax'):
            logits = tf.contrib.layers.fully_connected(input_fully,
                                                       self.config.class_num,
                                                       activation_fn=None)
        return logits