예제 #1
0
    def build_model3(self):
        """Build the graph: Bi-GRU utterance encoder + transformer context encoder/decoder.

        Reads ``self.multi_s1``, ``self.s2`` and ``self.dropout_rate``.
        Sets ``self.uttn_enc_hidden_size``, ``self.encoder_uttn_rnn``,
        ``self.loss``, ``self.decoded_ids``, ``self.scores``,
        ``self.global_step``, ``self.optimizer`` and ``self.train_op``.
        """
        # Each GRU direction gets half of the hidden size, so it must be even.
        assert not conf.hidden_size % 2, 'conf.hidden_size must be divisible by 2'
        self.uttn_enc_hidden_size = conf.hidden_size // 2
        batch_size, num_turns, length = shape_list(self.multi_s1)

        # embedding
        # [batch,turn,len,embed] (axis order follows the shape_list unpacking above)
        multi_s1_embed, _ = embedding(tf.expand_dims(self.multi_s1, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb)
        # [batch,len,embed]
        s2_embed, _ = embedding(tf.expand_dims(self.s2, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb)

        # utterance encoder
        uttn_input = tf.reshape(multi_s1_embed, [-1, length, conf.embed_size])  # [batch*turn,len,embed]
        uttn_mask = mask_nonpad_from_embedding(uttn_input)  # [batch*turn,len] 1 for nonpad; 0 for pad
        uttn_seqlen = tf.cast(tf.reduce_sum(uttn_mask, axis=-1), tf.int32)  # [batch*turn]
        # utterance-level bidirectional GRU
        self.encoder_uttn_rnn = Bi_RNN(cell_name='GRUCell', name='uttn_enc', hidden_size=self.uttn_enc_hidden_size, dropout_rate=self.dropout_rate)
        _, uttn_embed = self.encoder_uttn_rnn(uttn_input, uttn_seqlen)  # [batch*turn,len,2hid] [batch*turn,2hid]
        # [batch,turn,2hid]; from here on the turn axis plays the role of the sequence axis
        uttn_embed = tf.reshape(uttn_embed, [batch_size, num_turns, self.uttn_enc_hidden_size * 2])

        # transformer context encoder
        encoder_valid_mask = mask_nonpad_from_embedding(uttn_embed)  # [batch,turn] 1 for nonpad; 0 for pad
        encoder_input = add_timing_signal_1d(uttn_embed)  # add position embedding
        encoder_input = tf.layers.dropout(encoder_input, rate=self.dropout_rate)  # dropout

        encoder_output = transformer_encoder(encoder_input, encoder_valid_mask,
                                             hidden_size=conf.hidden_size,
                                             filter_size=conf.hidden_size * 4,
                                             num_heads=conf.num_heads,
                                             num_encoder_layers=conf.num_encoder_layers,
                                             dropout=self.dropout_rate,
                                             attention_dropout=self.dropout_rate,
                                             relu_dropout=self.dropout_rate,
                                             )

        # transformer decoder (training path)
        decoder_input = s2_embed
        decoder_valid_mask = mask_nonpad_from_embedding(decoder_input)  # [batch,len] 1 for nonpad; 0 for pad
        decoder_input = shift_right(decoder_input)  # original note: pad is used as eos
        decoder_input = add_timing_signal_1d(decoder_input)
        decoder_input = tf.layers.dropout(decoder_input, rate=self.dropout_rate)  # dropout

        decoder_output = transformer_decoder(decoder_input, encoder_output, decoder_valid_mask, encoder_valid_mask,
                                             cache=None,
                                             hidden_size=conf.hidden_size,
                                             filter_size=conf.hidden_size * 4,
                                             num_heads=conf.num_heads,
                                             num_decoder_layers=conf.num_decoder_layers,
                                             dropout=self.dropout_rate,
                                             attention_dropout=self.dropout_rate,
                                             relu_dropout=self.dropout_rate,
                                             )

        logits = proj_logits(decoder_output, conf.embed_size, conf.vocab_size, name='share_embedding')

        onehot_s2 = tf.one_hot(self.s2, depth=conf.vocab_size)  # [batch,len,vocab]

        xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=onehot_s2)  # [batch,len]
        weights = tf.to_float(tf.not_equal(self.s2, 0))  # [batch,len] 1 for nonpad; 0 for pad

        loss_num = xentropy * weights  # [batch,len]
        loss_den = weights  # [batch,len]

        # Average the cross entropy over non-pad target positions only.
        loss = tf.reduce_sum(loss_num) / tf.reduce_sum(loss_den)  # scalar
        self.loss = loss

        # transformer decoder (inference path)
        # Original note: everything placed in the cache becomes batch * beam
        # inside symbols_to_logits_fn.
        # Initialise the cache.
        cache = {
            'layer_%d' % layer: {
                # Holds the self-attention k,v computed at earlier decode steps
                # (starts with zero length on the time axis).
                'k': split_heads(tf.zeros([batch_size, 0, conf.embed_size]), conf.num_heads),
                'v': split_heads(tf.zeros([batch_size, 0, conf.embed_size]), conf.num_heads)
            } for layer in range(conf.num_decoder_layers)
        }
        for layer in range(conf.num_decoder_layers):
            # Every decoder layer attends to the encoder's top-layer output, so
            # each layer has its own encoder-decoder k,v that can be cached once.
            layer_name = 'layer_%d' % layer
            with tf.variable_scope('decoder/%s/encdec_attention/multihead_attention' % layer_name):
                k_encdec = tf.layers.dense(encoder_output, conf.embed_size, use_bias=False, name='k', reuse=tf.AUTO_REUSE)
                k_encdec = split_heads(k_encdec, conf.num_heads)
                v_encdec = tf.layers.dense(encoder_output, conf.embed_size, use_bias=False, name='v', reuse=tf.AUTO_REUSE)
                v_encdec = split_heads(v_encdec, conf.num_heads)
            cache[layer_name]['k_encdec'] = k_encdec
            cache[layer_name]['v_encdec'] = v_encdec
        cache['encoder_output'] = encoder_output
        cache['encoder_mask'] = encoder_valid_mask

        # position embedding
        position_embedding = get_timing_signal_1d(conf.max_decode_len, conf.embed_size)  # +eos [1,length+1,embed]

        def symbols_to_logits_fn(ids, i, cache):
            """Run one decoding step: map the latest ids to next-token logits."""
            ids = ids[:, -1:]  # [batch,1] keep only the last symbol
            target_embed, _ = embedding(tf.expand_dims(ids, axis=-1), conf.vocab_size, conf.embed_size, 'share_embedding', reuse=True)  # [batch,1,hidden]

            decoder_input = target_embed + position_embedding[:, i:i + 1, :]  # [batch,1,hidden]

            encoder_output = cache['encoder_output']
            encoder_mask = cache['encoder_mask']

            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                # The head and layer counts must match both the training decoder
                # and the per-layer cache built above, so they come from conf.
                # NOTE(review): the training decoder is sized with
                # conf.hidden_size while this call and the cache use
                # conf.embed_size; presumably the two are equal - confirm.
                decoder_output = transformer_decoder(decoder_input, encoder_output, None, encoder_mask,
                                                     cache=cache,  # inference relies on the cache
                                                     hidden_size=conf.embed_size,
                                                     filter_size=conf.embed_size * 4,
                                                     num_heads=conf.num_heads,
                                                     num_decoder_layers=conf.num_decoder_layers,
                                                     )
            logits = proj_logits(decoder_output, conf.embed_size, conf.vocab_size, name='share_embedding')  # [batch,1,vocab]
            ret = tf.squeeze(logits, axis=1)  # [batch,vocab]
            return ret, cache

        initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # <pad> doubles as <sos>

        def greedy_search_wrapper():
            """ Greedy Search """
            decoded_ids, scores = greedy_search(
                symbols_to_logits_fn,
                initial_ids,
                conf.max_decode_len,
                cache=cache,
                eos_id=conf.eos_id,
            )
            return decoded_ids, scores

        def beam_search_wrapper():
            """ Beam Search """
            decoded_ids, scores = beam_search(  # [batch,beam,len] [batch,beam]
                symbols_to_logits_fn,
                initial_ids,
                conf.beam_size,
                conf.max_decode_len,
                conf.vocab_size,
                alpha=0,
                states=cache,
                eos_id=conf.eos_id,
            )
            return decoded_ids, scores

        # Greedy search when the beam size is 1, beam search otherwise.
        decoded_ids, scores = tf.cond(tf.equal(conf.beam_size, 1), greedy_search_wrapper, beam_search_wrapper)

        self.decoded_ids = tf.identity(decoded_ids, name='decoded_ids')  # [batch,beam/1,len]
        self.scores = tf.identity(scores, name='scores')  # [batch,beam/1]

        self.global_step = tf.train.get_or_create_global_step()
        self.optimizer = tf.train.AdamOptimizer(learning_rate=conf.lr)
        self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
예제 #2
0
    def build_model1(self):
        """Build the hierarchical GRU encoder-decoder graph (no attention).

        Reads ``self.multi_s1``, ``self.s2`` and ``self.dropout_rate``.
        Sets the encoder/decoder RNN attributes, ``self.loss``,
        ``self.decoded_ids``, ``self.scores``, ``self.global_step``,
        ``self.optimizer`` and ``self.train_op``.
        """
        self.uttn_enc_hidden_size = 256
        self.ctx_enc_hidden_size = 256
        batch_size, num_turns, length = shape_list(self.multi_s1)

        # --- shared embedding for context and response ---
        # context: [batch,turn,len,embed] (axis order follows shape_list above)
        context_ids = tf.expand_dims(self.multi_s1, -1)
        context_embed, _ = embedding(context_ids, conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb)
        # response: [batch,len,embed]
        response_ids = tf.expand_dims(self.s2, -1)
        response_embed, _ = embedding(response_ids, conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb)

        # --- utterance encoder: one GRU pass per (batch, turn) utterance ---
        flat_utterances = tf.reshape(context_embed, [-1, length, conf.embed_size])  # [batch*turn,len,embed]
        flat_word_mask = mask_nonpad_from_embedding(flat_utterances)  # [batch*turn,len] 1 for nonpad; 0 for pad
        flat_word_count = tf.cast(tf.reduce_sum(flat_word_mask, axis=-1), tf.int32)  # [batch*turn]
        self.encoder_uttn_rnn = RNN(cell_name='GRUCell', name='uttn_enc', hidden_size=self.uttn_enc_hidden_size, dropout_rate=self.dropout_rate)
        _, utterance_vec = self.encoder_uttn_rnn(flat_utterances, flat_word_count)  # [batch*turn,hid]
        # [batch,turn,hid]; the turn axis now acts as the sequence axis
        utterance_vec = tf.reshape(utterance_vec, [batch_size, num_turns, self.uttn_enc_hidden_size])

        # --- context encoder: GRU over the per-turn utterance vectors ---
        turn_mask = mask_nonpad_from_embedding(utterance_vec)  # [batch,turn] 1 for nonpad; 0 for pad
        turn_count = tf.cast(tf.reduce_sum(turn_mask, axis=-1), tf.int32)  # [batch]
        self.encoder_ctx_rnn = RNN(cell_name='GRUCell', name='ctx_enc', hidden_size=self.ctx_enc_hidden_size, dropout_rate=self.dropout_rate)
        _, context_vec = self.encoder_ctx_rnn(utterance_vec, turn_count)  # [batch,hid]

        # --- RNN decoder, training path (no attention) ---
        response_mask = mask_nonpad_from_embedding(response_embed)  # [batch,len] 1 for nonpad; 0 for pad
        response_len = tf.cast(tf.reduce_sum(response_mask, axis=-1), tf.int32)  # [batch]

        train_inputs = shift_right(response_embed)  # original note: pad is used as eos
        train_inputs = tf.layers.dropout(train_inputs, rate=self.dropout_rate)

        # Append the context vector to the decoder input at every time step.
        context_per_step = tf.expand_dims(context_vec, axis=1)
        context_per_step = tf.tile(context_per_step, [1, shape_list(train_inputs)[1], 1])  # [batch,len,hid]
        train_inputs = tf.concat([train_inputs, context_per_step], axis=2)

        self.decoder_rnn = RNN(cell_name='GRUCell', name='dec', hidden_size=conf.embed_size, dropout_rate=self.dropout_rate)
        train_outputs, _ = self.decoder_rnn(train_inputs, response_len)

        train_logits = proj_logits(train_outputs, conf.embed_size, conf.vocab_size, name='share_embedding')

        target_onehot = tf.one_hot(self.s2, depth=conf.vocab_size)  # [batch,len,vocab]
        token_xent = tf.nn.softmax_cross_entropy_with_logits_v2(logits=train_logits, labels=target_onehot)  # [batch,len]
        nonpad = tf.to_float(tf.not_equal(self.s2, 0))  # [batch,len] 1 for nonpad; 0 for pad

        # Mean cross entropy over non-pad target positions.
        self.loss = tf.reduce_sum(token_xent * nonpad) / tf.reduce_sum(nonpad)  # scalar

        # --- RNN decoder, inference path (no attention) ---
        # Original note: everything placed in the cache becomes batch * beam
        # inside symbols_to_logits_fn.
        cache = {
            'state': self.decoder_rnn.cell.zero_state(batch_size, tf.float32),  # [batch,hid]
            'ctx': context_vec,  # [batch,hid]
        }

        def symbols_to_logits_fn(ids, i, cache):
            """Run one decoder step on the newest id; updates cache['state'] in place."""
            # ids: [batch,length]; only the last column is used
            newest_ids = ids[:, -1:]  # [batch,1]
            newest_embed, _ = embedding(tf.expand_dims(newest_ids, axis=-1), conf.vocab_size, conf.embed_size, 'share_embedding')  # [batch,1,embed]
            step_input = tf.squeeze(newest_embed, axis=1)  # [batch,embed]

            # Append the cached context vector, mirroring the training input.
            step_input = tf.concat([step_input, cache['ctx']], axis=-1)

            step_output, cache['state'] = self.decoder_rnn.one_step(step_input, cache['state'])

            step_logits = proj_logits(step_output, conf.embed_size, conf.vocab_size, name='share_embedding')

            return step_logits, cache

        initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # <pad> doubles as <sos>

        def greedy_search_wrapper():
            """ Greedy Search """
            best_ids, best_scores = greedy_search(
                symbols_to_logits_fn,
                initial_ids,
                conf.max_decode_len,
                cache=cache,
                eos_id=conf.eos_id,
            )
            return best_ids, best_scores

        def beam_search_wrapper():
            """ Beam Search """
            beam_ids, beam_scores = beam_search(  # [batch,beam,len] [batch,beam]
                symbols_to_logits_fn,
                initial_ids,
                conf.beam_size,
                conf.max_decode_len,
                conf.vocab_size,
                alpha=0,
                states=cache,
                eos_id=conf.eos_id,
            )
            return beam_ids, beam_scores

        # Greedy search when the beam size is 1, beam search otherwise.
        use_greedy = tf.equal(conf.beam_size, 1)
        decoded_ids, scores = tf.cond(use_greedy, greedy_search_wrapper, beam_search_wrapper)

        self.decoded_ids = tf.identity(decoded_ids, name='decoded_ids')  # [batch,beam/1,len]
        self.scores = tf.identity(scores, name='scores')  # [batch,beam/1]

        self.global_step = tf.train.get_or_create_global_step()
        self.optimizer = tf.train.AdamOptimizer(learning_rate=conf.lr)
        self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
예제 #3
0
        def symbols_to_logits_fn(ids, i, cache):
            """Run one inference step of the two-level-attention RNN decoder.

            Args:
                ids: ids decoded so far, [batch,length]; only the last column is used.
                i: step index supplied by the search routine (not used here).
                cache: dict with 'dec_rnn_state', 'ctx_rnn_state' and 'uttn_repre'.

            Returns:
                ``(logits, cache)`` with ``cache['dec_rnn_state']`` replaced by
                the new decoder state.
            """
            newest_ids = ids[:, -1:]  # [batch,1] keep only the last symbol
            newest_embed, _ = embedding(tf.expand_dims(newest_ids, axis=-1), conf.vocab_size, conf.embed_size, 'share_embedding')  # [batch,1,embed]
            step_embed = tf.squeeze(newest_embed, axis=1)  # [batch,embed]

            prev_dec_state = cache['dec_rnn_state']

            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                # Inner loop: for this decode step, advance the context RNN one
                # turn at a time, collecting its outputs into a context sequence.
                def turns_remaining(turn_idx, *_):
                    return tf.less(turn_idx, tf.reduce_max(ctx_seqlen))

                def encode_one_turn(turn_idx, ctx_state, ctx_outputs):
                    # The query combines the running context state with the
                    # previous decoder state, repeated for every word position.
                    query = tf.concat([ctx_state, prev_dec_state], axis=-1)  # [batch,h]
                    query = tf.tile(tf.expand_dims(query, 1), [1, length, 1])  # [batch,len,h]

                    # Select this turn for every batch element.
                    # cache['uttn_repre'] is [batch,turn,len,h]
                    query = tf.concat([cache['uttn_repre'][:, turn_idx, :, :], query], -1)  # [batch,len,h]
                    # NOTE(review): uttn_mask/batch_size come from the enclosing
                    # scope, not the cache - confirm they line up with the cached
                    # tensors when the search routine expands the batch.
                    word_mask = tf.reshape(uttn_mask, [batch_size, num_turns, length])[:, turn_idx, :]  # [batch,len]

                    # word-level attention
                    hidden = tf.layers.dense(query, 128, activation=tf.nn.tanh, use_bias=True, name='word_level_attn/layer1')
                    scores = tf.layers.dense(hidden, 1, use_bias=True, name='word_level_attn/layer2')  # [batch,len,1]
                    scores = tf.squeeze(scores, -1) + (1. - word_mask) * -1e9  # push pad positions to ~-inf
                    word_weights = tf.nn.softmax(scores)  # [batch,len]
                    turn_summary = tf.reduce_sum(tf.expand_dims(word_weights, -1) * cache['uttn_repre'][:, turn_idx, :, :], 1)  # [batch,h]

                    turn_output, ctx_state = self.encoder_ctx_rnn.one_step(turn_summary, ctx_state)
                    # Append this turn's output along the turn axis.
                    ctx_outputs = tf.concat([ctx_outputs, tf.expand_dims(turn_output, 1)], 1)  # [batch,turn,h]
                    return turn_idx + 1, ctx_state, ctx_outputs

                # Run the inner loop; the output starts with zero turns and grows.
                _, _, ctx_outputs = tf.while_loop(
                    turns_remaining,
                    encode_one_turn,
                    loop_vars=[tf.constant(0, dtype=tf.int32),
                               cache['ctx_rnn_state'],
                               tf.zeros([shape_list(cache['ctx_rnn_state'])[0], 0, self.ctx_enc_hidden_size])],
                    shape_invariants=[
                        tf.TensorShape([]),
                        nest.map_structure(get_state_shape_invariants, init_ctx_encoder_state),
                        tf.TensorShape([None, None, self.ctx_enc_hidden_size]),
                    ])

                # context-level attention
                # ctx_outputs: [batch,turn,h]; prev_dec_state: [batch,h]
                # Tiling to the produced turn count (instead of the fixed
                # num_turns) uses only as many turns as this batch needs.
                dec_query = tf.tile(tf.expand_dims(prev_dec_state, axis=1), [1, shape_list(ctx_outputs)[1], 1])  # [batch,turn,h]
                dec_query = tf.concat([dec_query, ctx_outputs], 2)  # [batch,turn,h]
                ctx_hidden = tf.layers.dense(dec_query, 128, activation=tf.nn.tanh, use_bias=True, name='ctx_level_attn/layer1')
                ctx_scores = tf.layers.dense(ctx_hidden, 1, use_bias=True, name='ctx_level_attn/layer2')  # [batch,turn,1]
                ctx_scores = tf.squeeze(ctx_scores, -1) + (1. - ctx_mask) * -1e9  # [batch,turn]
                turn_weights = tf.nn.softmax(ctx_scores)  # [batch,turn]
                attended_ctx = tf.reduce_sum(tf.expand_dims(turn_weights, -1) * ctx_outputs, 1)  # [batch,h]

                rnn_input = tf.concat([attended_ctx, step_embed], -1)  # [batch,h]
                rnn_output, next_dec_state = self.decoder_rnn.one_step(rnn_input, prev_dec_state)

                cache['dec_rnn_state'] = next_dec_state

                logits = proj_logits(rnn_output, conf.embed_size, conf.vocab_size, name='share_embedding')

                return logits, cache
예제 #4
0
        def loop_body(time_step, dec_rnn_state, dec_rnn_output):
            """Run one training-time decoder step with word- and context-level attention.

            Args:
                time_step: scalar index of the current decode step.
                dec_rnn_state: decoder RNN state from the previous step.
                dec_rnn_output: decoder outputs collected so far (time on axis 1).

            Returns:
                ``(time_step + 1, new decoder state, outputs with this step appended)``.
            """
            # Inner loop: for this decode step, advance the context RNN one turn
            # at a time, collecting its outputs into a context sequence.
            def turns_remaining(turn_idx, *_):
                return tf.less(turn_idx, tf.reduce_max(ctx_seqlen))

            def encode_one_turn(turn_idx, ctx_state, ctx_outputs):
                # The query combines the running context state with the previous
                # decoder state, repeated for every word position.
                query = tf.concat([ctx_state, dec_rnn_state], axis=-1)  # [batch,h]
                query = tf.tile(tf.expand_dims(query, 1), [1, length, 1])  # [batch,len,h]

                # Select this turn for every batch element.
                # uttn_repre is [batch,turn,len,h]
                query = tf.concat([uttn_repre[:, turn_idx, :, :], query], -1)  # [batch,len,h]
                word_mask = tf.reshape(uttn_mask, [batch_size, num_turns, length])[:, turn_idx, :]  # [batch,len]

                # word-level attention
                hidden = tf.layers.dense(query, 128, activation=tf.nn.tanh, use_bias=True, name='word_level_attn/layer1')
                scores = tf.layers.dense(hidden, 1, use_bias=True, name='word_level_attn/layer2')  # [batch,len,1]
                scores = tf.squeeze(scores, -1) + (1. - word_mask) * -1e9  # push pad positions to ~-inf
                word_weights = tf.nn.softmax(scores)  # [batch,len]
                turn_summary = tf.reduce_sum(tf.expand_dims(word_weights, -1) * uttn_repre[:, turn_idx, :, :], 1)  # [batch,h]

                turn_output, ctx_state = self.encoder_ctx_rnn.one_step(turn_summary, ctx_state)

                # Append this turn's output along the turn axis.
                ctx_outputs = tf.concat([ctx_outputs, tf.expand_dims(turn_output, 1)], 1)  # [batch,turn,h]
                return turn_idx + 1, ctx_state, ctx_outputs

            # Run the inner loop; the output starts with zero turns and grows.
            _, _, ctx_outputs = tf.while_loop(
                turns_remaining,
                encode_one_turn,
                loop_vars=[tf.constant(0, dtype=tf.int32),
                           init_ctx_encoder_state,
                           tf.zeros([batch_size, 0, self.ctx_enc_hidden_size])],
                shape_invariants=[
                    tf.TensorShape([]),
                    nest.map_structure(get_state_shape_invariants, init_ctx_encoder_state),
                    tf.TensorShape([None, None, self.ctx_enc_hidden_size]),
                ])

            # context-level attention
            # ctx_outputs: [batch,turn,h]; dec_rnn_state: [batch,h]
            # Tiling to the produced turn count (instead of the fixed num_turns)
            # uses only as many turns as this batch needs.
            dec_query = tf.tile(tf.expand_dims(dec_rnn_state, axis=1), [1, shape_list(ctx_outputs)[1], 1])  # [batch,turn,h]
            dec_query = tf.concat([dec_query, ctx_outputs], 2)  # [batch,turn,h]
            ctx_hidden = tf.layers.dense(dec_query, 128, activation=tf.nn.tanh, use_bias=True, name='ctx_level_attn/layer1')
            ctx_scores = tf.layers.dense(ctx_hidden, 1, use_bias=True, name='ctx_level_attn/layer2')  # [batch,turn,1]
            ctx_scores = tf.squeeze(ctx_scores, -1) + (1. - ctx_mask) * -1e9  # [batch,turn]
            turn_weights = tf.nn.softmax(ctx_scores)  # [batch,turn]
            attended_ctx = tf.reduce_sum(tf.expand_dims(turn_weights, -1) * ctx_outputs, 1)  # [batch,h]

            # Feed the attended context together with this step's decoder input.
            rnn_input = tf.concat([attended_ctx, decoder_input[:, time_step, :]], -1)  # [batch,h]
            step_output, next_dec_state = self.decoder_rnn.one_step(rnn_input, dec_rnn_state)

            grown_outputs = tf.concat([dec_rnn_output, tf.expand_dims(step_output, 1)], 1)

            return time_step + 1, next_dec_state, grown_outputs
예제 #5
0
    def build_model2(self):
        self.uttn_enc_hidden_size = 256
        self.ctx_enc_hidden_size = 256
        batch_size, num_turns, length = shape_list(self.multi_s1)

        # embedding
        # [batch,len,turn,embed]
        multi_s1_embed, _ = embedding(tf.expand_dims(self.multi_s1, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb)
        # [batch,len,embed]
        s2_embed, _ = embedding(tf.expand_dims(self.s2, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=conf.pretrain_emb)

        # uttn encoder
        uttn_input = tf.reshape(multi_s1_embed, [-1, length, conf.embed_size])  # [batch*turn,len,embed]
        uttn_mask = mask_nonpad_from_embedding(uttn_input)  # [batch*turn,len] 1 for nonpad; 0 for pad
        uttn_seqlen = tf.cast(tf.reduce_sum(uttn_mask, axis=-1), tf.int32)  # [batch*turn]
        # uttn-gru
        self.encoder_uttn_rnn = Bi_RNN(cell_name='GRUCell', name='uttn_enc', hidden_size=self.uttn_enc_hidden_size, dropout_rate=self.dropout_rate)
        uttn_repre, uttn_embed = self.encoder_uttn_rnn(uttn_input, uttn_seqlen)  # [batch*turn,len,2hid] [batch*turn,2hid]
        uttn_embed = tf.reshape(uttn_embed, [batch_size, num_turns, self.uttn_enc_hidden_size * 2])  # [batch,turn,2hid]
        uttn_repre = tf.reshape(uttn_repre, [batch_size, num_turns, length, self.uttn_enc_hidden_size * 2])  # [batch,turn,len,2hid]

        # ctx encoder
        ctx_mask = mask_nonpad_from_embedding(uttn_embed)  # [batch,turn] 1 for nonpad; 0 for pad
        ctx_seqlen = tf.cast(tf.reduce_sum(ctx_mask, axis=-1), tf.int32)  # [batch]

        # reverse turn
        uttn_repre = tf.reverse_sequence(uttn_repre, seq_lengths=ctx_seqlen, seq_axis=1, batch_axis=0)

        # ctx-gru
        self.encoder_ctx_rnn = RNN(cell_name='GRUCell', name='ctx_enc', hidden_size=self.ctx_enc_hidden_size, dropout_rate=self.dropout_rate)
        init_ctx_encoder_state = self.encoder_ctx_rnn.cell.zero_state(batch_size, tf.float32)

        # rnn decoder train
        s2_mask = mask_nonpad_from_embedding(s2_embed)  # [batch,len] 1 for nonpad; 0 for pad
        s2_seqlen = tf.cast(tf.reduce_sum(s2_mask, axis=-1), tf.int32)  # [batch]

        decoder_input = shift_right(s2_embed)  # 用pad当做eos
        decoder_input = tf.layers.dropout(decoder_input, rate=self.dropout_rate)  # dropout

        self.decoder_rnn = RNN(cell_name='GRUCell', name='dec', hidden_size=conf.embed_size, dropout_rate=self.dropout_rate)
        init_decoder_state = self.decoder_rnn.cell.zero_state(batch_size, tf.float32)

        # 两重循环 -_-
        # 外层循环是decoder解码的每一步 time_step
        def loop_condition(time_step, *_):
            return tf.less(time_step, tf.reduce_max(s2_seqlen))

        def loop_body(time_step, dec_rnn_state, dec_rnn_output):
            # cal decoder ctx
            # word level attention

            # 内层循环是对于解码每一步 ctx_rnn上turn的每一步,最终生成ctx序列 turn_step
            def inner_loop_condition(turn_step, *_):
                return tf.less(turn_step, tf.reduce_max(ctx_seqlen))

            def inner_loop_body(turn_step, ctx_rnn_state, ctx_rnn_output):
                # 根据si, ctx_init_state 递归计算多个turn的句子ctx
                q_antecedent = tf.concat([ctx_rnn_state, dec_rnn_state], axis=-1)  # [batch, h]
                q_antecedent = tf.tile(tf.expand_dims(q_antecedent, 1), [1, length, 1])  # [batch,len,h]

                # 抽取每个batch的第i个turn
                # sent_repre [batch,turn,len,h]
                q_antecedent = tf.concat([uttn_repre[:, turn_step, :, :], q_antecedent], -1)  # [batch,len,h] 
                uttn_mask_in_turn = tf.reshape(uttn_mask, [batch_size, num_turns, length])[:, turn_step, :]  # [batch,len]

                # word-level-attn
                h = tf.layers.dense(q_antecedent, 128, activation=tf.nn.tanh, use_bias=True, name='word_level_attn/layer1')
                energy = tf.layers.dense(h, 1, use_bias=True, name='word_level_attn/layer2')  # [batch,len,1]
                energy = tf.squeeze(energy, -1) + (1. - uttn_mask_in_turn) * -1e9
                alpha = tf.nn.softmax(energy)  # [batch,len]
                r_in_turn = tf.reduce_sum(tf.expand_dims(alpha, -1) * uttn_repre[:, turn_step, :, :], 1)  # [batch,h]

                ctx_rnn_output_, ctx_rnn_state = self.encoder_ctx_rnn.one_step(r_in_turn, ctx_rnn_state)

                # attch
                ctx_rnn_output = tf.concat([ctx_rnn_output, tf.expand_dims(ctx_rnn_output_, 1)], 1)  # [batch,turn,h]
                return turn_step + 1, ctx_rnn_state, ctx_rnn_output

            # start inner loop
            final_turn_step, final_state, ctx_rnn_output = tf.while_loop(
                inner_loop_condition,
                inner_loop_body,
                loop_vars=[tf.constant(0, dtype=tf.int32),
                           init_ctx_encoder_state,
                           tf.zeros([batch_size, 0, self.ctx_enc_hidden_size])],
                shape_invariants=[
                    tf.TensorShape([]),
                    nest.map_structure(get_state_shape_invariants, init_ctx_encoder_state),
                    tf.TensorShape([None, None, self.ctx_enc_hidden_size]),
                ])

            # ctx_rnn_output  # [batch,turn,h]
            # dec_rnn_state  # [batch,h]
            # ctx-level-attn
            # q_antecedent = tf.tile(tf.expand_dims(dec_rnn_state, axis=1), [1, num_turns, 1])  # [batch,turn,h]
            # 这样只拿当前batch中的尽可能小的turns数量而不是固定turn
            q_antecedent = tf.tile(tf.expand_dims(dec_rnn_state, axis=1), [1, shape_list(ctx_rnn_output)[1], 1])  # [batch,turn,h]
            q_antecedent = tf.concat([q_antecedent, ctx_rnn_output], 2)  # [batch,turn,h]
            h = tf.layers.dense(q_antecedent, 128, activation=tf.nn.tanh, use_bias=True, name='ctx_level_attn/layer1')
            energy = tf.layers.dense(h, 1, use_bias=True, name='ctx_level_attn/layer2')  # [batch,turn,1]
            energy = tf.squeeze(energy, -1) + (1. - ctx_mask) * -1e9  # [batch,turn]
            alpha = tf.nn.softmax(energy)  # [batch,turn]
            ctx_input_in_dec = tf.reduce_sum(tf.expand_dims(alpha, -1) * ctx_rnn_output, 1)  # [batch,h]

            dec_rnn_input = tf.concat([ctx_input_in_dec, decoder_input[:, time_step, :]], -1)  # [batch,h]
            dec_rnn_output_, dec_rnn_state = self.decoder_rnn.one_step(dec_rnn_input, dec_rnn_state)

            dec_rnn_output = tf.concat([dec_rnn_output, tf.expand_dims(dec_rnn_output_, 1)], 1)

            return time_step + 1, dec_rnn_state, dec_rnn_output

        # start outer loop
        final_time_step, final_state, dec_rnn_output = tf.while_loop(
            loop_condition,
            loop_body,
            loop_vars=[tf.constant(0, dtype=tf.int32),
                       init_decoder_state,
                       tf.zeros([batch_size, 0, conf.embed_size])],
            shape_invariants=[
                tf.TensorShape([]),
                nest.map_structure(get_state_shape_invariants, init_decoder_state),
                tf.TensorShape([None, None, conf.embed_size]),
            ])

        decoder_output = dec_rnn_output

        logits = proj_logits(decoder_output, conf.embed_size, conf.vocab_size, name='share_embedding')

        onehot_s2 = tf.one_hot(self.s2, depth=conf.vocab_size)  # [batch,len,vocab]

        xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=onehot_s2)  # [batch,len]
        weights = tf.to_float(tf.not_equal(self.s2, 0))  # [batch,len] 1 for nonpad; 0 for pad

        loss_num = xentropy * weights  # [batch,len]
        loss_den = weights  # [batch,len]

        loss = tf.reduce_sum(loss_num) / tf.reduce_sum(loss_den)  # scalar
        self.loss = loss

        # rnn decoder infer
        # 放在cache里面的在后面symbols_to_logits_fn函数中都会变成batch * beam
        cache = {'dec_rnn_state': self.decoder_rnn.cell.zero_state(batch_size, tf.float32),  # [batch,hid]
                 'ctx_rnn_state': self.encoder_ctx_rnn.cell.zero_state(batch_size, tf.float32),  # [batch,hid]
                 'uttn_repre': uttn_repre  # [batch,turn,len,2hid]
                 }

        def symbols_to_logits_fn(ids, i, cache):
            # ids [batch,length]
            pred_target = ids[:, -1:]  # [batch,1] 截取最后一个
            target_embed, _ = embedding(tf.expand_dims(pred_target, axis=-1), conf.vocab_size, conf.embed_size, 'share_embedding')  # [batch,1,embed]
            decoder_input = tf.squeeze(target_embed, axis=1)  # [batch,embed]

            dec_rnn_state = cache['dec_rnn_state']

            with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                # 内层循环是对于解码每一步 ctx_rnn上turn的每一步,最终生成ctx序列 turn_step
                def inner_loop_condition(turn_step, *_):
                    return tf.less(turn_step, tf.reduce_max(ctx_seqlen))

                def inner_loop_body(turn_step, ctx_rnn_state, ctx_rnn_output):
                    # 根据si, ctx_init_state 递归计算多个turn的句子ctx
                    q_antecedent = tf.concat([ctx_rnn_state, dec_rnn_state], axis=-1)  # [batch, h]
                    q_antecedent = tf.tile(tf.expand_dims(q_antecedent, 1), [1, length, 1])  # [batch,len,h]

                    # 抽取每个batch的第i个turn
                    # sent_repre [batch,turn,len,h]
                    q_antecedent = tf.concat([cache['uttn_repre'][:, turn_step, :, :], q_antecedent], -1)  # [batch,len,h] 
                    uttn_mask_in_turn = tf.reshape(uttn_mask, [batch_size, num_turns, length])[:, turn_step, :]  # [batch,len]

                    # word-level-attn
                    h = tf.layers.dense(q_antecedent, 128, activation=tf.nn.tanh, use_bias=True, name='word_level_attn/layer1')
                    energy = tf.layers.dense(h, 1, use_bias=True, name='word_level_attn/layer2')  # [batch,len,1]
                    energy = tf.squeeze(energy, -1) + (1. - uttn_mask_in_turn) * -1e9
                    alpha = tf.nn.softmax(energy)  # [batch,len]
                    r_in_turn = tf.reduce_sum(tf.expand_dims(alpha, -1) * cache['uttn_repre'][:, turn_step, :, :], 1)  # [batch,h]

                    ctx_rnn_output_, ctx_rnn_state = self.encoder_ctx_rnn.one_step(r_in_turn, ctx_rnn_state)
                    # attch
                    ctx_rnn_output = tf.concat([ctx_rnn_output, tf.expand_dims(ctx_rnn_output_, 1)], 1)  # [batch,turn,h]
                    return turn_step + 1, ctx_rnn_state, ctx_rnn_output

                # start inner loop
                final_turn_step, final_state, ctx_rnn_output = tf.while_loop(
                    inner_loop_condition,
                    inner_loop_body,
                    loop_vars=[tf.constant(0, dtype=tf.int32),
                               cache['ctx_rnn_state'],
                               tf.zeros([shape_list(cache['ctx_rnn_state'])[0], 0, self.ctx_enc_hidden_size])],
                    shape_invariants=[
                        tf.TensorShape([]),
                        nest.map_structure(get_state_shape_invariants, init_ctx_encoder_state),
                        tf.TensorShape([None, None, self.ctx_enc_hidden_size]),
                    ])

                # ctx_rnn_output  # [batch,turn,h]
                # dec_rnn_state  # [batch,h]
                # ctx-level-attn
                # q_antecedent = tf.tile(tf.expand_dims(dec_rnn_state, axis=1), [1, num_turns, 1])  # [batch,turn,h]
                # 这样只拿当前batch中的尽可能小的turns数量而不是固定turn
                q_antecedent = tf.tile(tf.expand_dims(dec_rnn_state, axis=1), [1, shape_list(ctx_rnn_output)[1], 1])  # [batch,turn,h]
                q_antecedent = tf.concat([q_antecedent, ctx_rnn_output], 2)  # [batch,turn,h]
                h = tf.layers.dense(q_antecedent, 128, activation=tf.nn.tanh, use_bias=True, name='ctx_level_attn/layer1')
                energy = tf.layers.dense(h, 1, use_bias=True, name='ctx_level_attn/layer2')  # [batch,turn,1]
                energy = tf.squeeze(energy, -1) + (1. - ctx_mask) * -1e9  # [batch,turn]
                alpha = tf.nn.softmax(energy)  # [batch,turn]
                ctx_input_in_dec = tf.reduce_sum(tf.expand_dims(alpha, -1) * ctx_rnn_output, 1)  # [batch,h]

                dec_rnn_input = tf.concat([ctx_input_in_dec, decoder_input], -1)  # [batch,h]
                dec_rnn_output_, dec_rnn_state = self.decoder_rnn.one_step(dec_rnn_input, dec_rnn_state)

                cache['dec_rnn_state'] = dec_rnn_state

                logits = proj_logits(dec_rnn_output_, conf.embed_size, conf.vocab_size, name='share_embedding')

                return logits, cache

        initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # <pad>为<sos>

        def greedy_search_wrapper():
            """ Greedy Search """
            decoded_ids, scores = greedy_search(
                symbols_to_logits_fn,
                initial_ids,
                conf.max_decode_len,
                cache=cache,
                eos_id=conf.eos_id,
            )
            return decoded_ids, scores

        def beam_search_wrapper():
            """ Beam Search """
            decoded_ids, scores = beam_search(  # [batch,beam,len] [batch,beam]
                symbols_to_logits_fn,
                initial_ids,
                conf.beam_size,
                conf.max_decode_len,
                conf.vocab_size,
                alpha=0,
                states=cache,
                eos_id=conf.eos_id,
            )
            return decoded_ids, scores

        decoded_ids, scores = tf.cond(tf.equal(conf.beam_size, 1), greedy_search_wrapper, beam_search_wrapper)

        self.decoded_ids = tf.identity(decoded_ids, name='decoded_ids')  # [batch,beam/1,len]
        self.scores = tf.identity(scores, name='scores')  # [batch,beam/1]

        self.global_step = tf.train.get_or_create_global_step()
        self.optimizer = tf.train.AdamOptimizer(learning_rate=conf.lr)
        self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
예제 #6
0
파일: s2s_model.py 프로젝트: sjx0451/QizNLP
    def build_model3(self):
        """Build the seq2seq graph: bi-GRU encoder, EncDecAttention, hand-unrolled GRU decoder.

        Reads ``self.s1``, ``self.s2`` and ``self.dropout_rate``. Sets
        ``self.bilstm_encoder1``, ``self.loss``, ``self.decoded_ids``,
        ``self.scores``, ``self.global_step``, ``self.optimizer`` and
        ``self.train_op``.

        The training decoder is unrolled step by step with ``tf.while_loop``;
        the inference decoder is driven by ``greedy_search`` or ``beam_search``
        (selected on ``conf.beam_size``) through ``symbols_to_logits_fn``.
        Source/target embeddings and the output projection all use the
        ``'share_embedding'`` name with ``conf.embed_size``.
        """
        # No pretrained embedding table is loaded here; to use one, e.g.:
        # pretrained_word_embeddings = np.load(f'{curr_dir}/pretrain_emb_300.npy')
        pretrained_word_embeddings = None

        # embedding  [batch,len,embed]
        s1_embed, _ = embedding(tf.expand_dims(self.s1, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=pretrained_word_embeddings)
        s1_mask = mask_nonpad_from_embedding(s1_embed)  # [batch,len] 1 for nonpad; 0 for pad
        s1_seqlen = tf.cast(tf.reduce_sum(s1_mask, axis=-1), tf.int32)  # [batch]

        # encoder
        encoder_input = s1_embed
        # NOTE(review): tf.layers.dropout defaults to training=False, which makes
        # this call a no-op regardless of rate — confirm this is intended.
        encoder_input = tf.layers.dropout(encoder_input, rate=self.dropout_rate)  # dropout

        with tf.variable_scope('birnn_encoder'):
            self.bilstm_encoder1 = Bi_RNN(cell_name='GRUCell', hidden_size=conf.hidden_size, dropout_rate=self.dropout_rate)
            encoder_output, _ = self.bilstm_encoder1(encoder_input, s1_seqlen)  # [batch,len,2hid]

        batch_size = shape_list(encoder_input)[0]

        # decoder
        decoder_rnn = getattr(tf.nn.rnn_cell, 'GRUCell')(conf.hidden_size)  # GRUCell/LSTMCell
        encdec_atten = EncDecAttention(encoder_output, s1_seqlen, conf.hidden_size)

        s2_embed, _ = embedding(tf.expand_dims(self.s2, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=pretrained_word_embeddings)
        s2_mask = mask_nonpad_from_embedding(s2_embed)  # [batch,len] 1 for nonpad; 0 for pad
        s2_seqlen = tf.cast(tf.reduce_sum(s2_mask, -1), tf.int32)  # [batch]

        decoder_input = s2_embed
        decoder_input = shift_right(decoder_input)  # original note: pad is used as eos
        decoder_input = tf.layers.dropout(decoder_input, rate=self.dropout_rate)  # dropout

        init_decoder_state = decoder_rnn.zero_state(batch_size, tf.float32)
        # For an LSTM cell: init_decoder_state = tf.nn.rnn_cell.LSTMStateTuple(c, h)

        # Loop variables: the two output buffers start with a zero-length time
        # axis and grow by one step per iteration.
        time_step = tf.constant(0, dtype=tf.int32)
        rnn_output = tf.zeros([batch_size, 0, conf.hidden_size])
        context_output = tf.zeros([batch_size, 0, conf.hidden_size * 2])  # attention contexts

        def loop_condition(time_step, *_):
            # Run until the longest target sequence in the batch is consumed.
            return tf.less(time_step, tf.reduce_max(s2_seqlen))

        def loop_body(time_step, prev_rnn_state, rnn_output, context_output):
            # attention: query with the previous state (the .h part for LSTM-style states)
            s = prev_rnn_state if isinstance(decoder_rnn, tf.nn.rnn_cell.GRUCell) else prev_rnn_state.h
            context = encdec_atten(s)  # [batch,hidden]
            context_output = tf.concat([context_output, tf.expand_dims(context, axis=1)], axis=1)

            # construct rnn input
            rnn_input = tf.concat([decoder_input[:, time_step, :], context], axis=-1)  # [batch,hidden+]  use attention
            # rnn_input = decoder_input[:, time_step, :]  # [batch,hidden]  not use attention

            # run rnn
            current_output, rnn_state = decoder_rnn(rnn_input, prev_rnn_state)

            # append to output bucket via length dim
            rnn_output = tf.concat([rnn_output, tf.expand_dims(current_output, axis=1)], axis=1)

            return time_step + 1, rnn_state, rnn_output, context_output

        # start loop; only the stacked rnn outputs are used afterwards
        _, _, rnn_output, _ = tf.while_loop(
            loop_condition,
            loop_body,
            loop_vars=[time_step, init_decoder_state, rnn_output, context_output],
            shape_invariants=[
                tf.TensorShape([]),
                nest.map_structure(get_state_shape_invariants, init_decoder_state),
                tf.TensorShape([None, None, conf.hidden_size]),
                tf.TensorShape([None, None, conf.hidden_size * 2]),
            ])
        decoder_output = rnn_output

        logits = proj_logits(decoder_output, conf.embed_size, conf.vocab_size, name='share_embedding')

        onehot_s2 = tf.one_hot(self.s2, depth=conf.vocab_size)  # [batch,len,vocab]

        xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=onehot_s2)  # [batch,len]
        weights = tf.to_float(tf.not_equal(self.s2, 0))  # [batch,len] 1 for nonpad; 0 for pad

        loss_num = xentropy * weights  # [batch,len]
        loss_den = weights  # [batch,len]

        # Average cross-entropy over non-pad target positions.
        loss = tf.reduce_sum(loss_num) / tf.reduce_sum(loss_den)  # scalar
        self.loss = loss

        # decoder infer
        cache = {'state': decoder_rnn.zero_state(batch_size, tf.float32)}

        def symbols_to_logits_fn(ids, i, cache):
            # ids [batch,length]
            pred_target = ids[:, -1:]  # keep only the last generated id  [batch,1]
            embed_target, _ = embedding(tf.expand_dims(pred_target, axis=-1), conf.vocab_size, conf.embed_size, 'share_embedding')  # [batch,length,embed]
            decoder_input = tf.squeeze(embed_target, axis=1)  # [batch,embed]

            # if use attention
            s = cache['state'] if isinstance(decoder_rnn, tf.nn.rnn_cell.GRUCell) else cache['state'].h
            context = encdec_atten(s, beam_size=conf.beam_size)  # [batch,hidden]
            decoder_input = tf.concat([decoder_input, context], axis=-1)  # [batch,hidden+]

            # run rnn
            decoder_output, cache['state'] = decoder_rnn(decoder_input, cache['state'])

            # Use the same size argument as the training-time projection so both
            # calls address the shared 'share_embedding' projection identically.
            logits = proj_logits(decoder_output, conf.embed_size, conf.vocab_size, name='share_embedding')

            return logits, cache

        initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # <pad> serves as <sos>

        def greedy_search_wrapper():
            """ Greedy Search """
            decoded_ids, scores = greedy_search(
                symbols_to_logits_fn,
                initial_ids,
                max_decode_len=conf.max_decode_len,
                cache=cache,
                eos_id=conf.eos_id,
            )
            return decoded_ids, scores

        def beam_search_wrapper():
            """ Beam Search """
            decoded_ids, scores = beam_search(  # [batch,beam,len] [batch,beam]
                symbols_to_logits_fn,
                initial_ids,
                beam_size=conf.beam_size,
                max_decode_len=conf.max_decode_len,
                vocab_size=conf.vocab_size,
                states=cache,
                eos_id=conf.eos_id,
                gamma=conf.gamma,
                num_group=conf.num_group,
                top_k=conf.top_k,
            )
            return decoded_ids, scores

        decoded_ids, scores = tf.cond(tf.equal(conf.beam_size, 1), greedy_search_wrapper, beam_search_wrapper)

        self.decoded_ids = tf.identity(decoded_ids, name='decoded_ids')  # [batch,beam/1,len]
        self.scores = tf.identity(scores, name='scores')  # [batch,beam/1]

        self.global_step = tf.train.get_or_create_global_step()
        self.optimizer = tf.train.AdamOptimizer(learning_rate=conf.lr)
        self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
예제 #7
0
파일: s2s_model.py 프로젝트: sjx0451/QizNLP
    def build_model2(self):
        """Build the seq2seq graph: bi-GRU encoder, tf.contrib BahdanauAttention, GRU decoder.

        Reads ``self.s1``, ``self.s2`` and ``self.dropout_rate``. Sets
        ``self.bilstm_encoder1``, ``self.loss``, ``self.decoded_ids``,
        ``self.scores``, ``self.global_step``, ``self.optimizer`` and
        ``self.train_op``.

        Training runs an ``AttentionWrapper`` cell through ``tf.nn.dynamic_rnn``.
        Inference builds a second attention mechanism over the encoder output
        tiled ``conf.beam_size`` times and steps it from ``greedy_search`` or
        ``beam_search`` via ``symbols_to_logits_fn``. The original author's
        notes below mark the inference path as problematic and not recommended.
        """
        # No pretrained embedding table is loaded here; to use one, e.g.:
        # pretrained_word_embeddings = np.load(f'{curr_dir}/pretrain_emb_300.npy')
        pretrained_word_embeddings = None

        # embedding  [batch,len,embed]
        s1_embed, _ = embedding(tf.expand_dims(self.s1, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=pretrained_word_embeddings)
        s1_mask = mask_nonpad_from_embedding(s1_embed)  # [batch,len] 1 for nonpad; 0 for pad
        s1_seqlen = tf.cast(tf.reduce_sum(s1_mask, axis=-1), tf.int32)  # [batch]

        # encoder
        encoder_input = s1_embed
        # NOTE(review): tf.layers.dropout defaults to training=False, which makes
        # this call a no-op regardless of rate — confirm this is intended.
        encoder_input = tf.layers.dropout(encoder_input, rate=self.dropout_rate)  # dropout

        with tf.variable_scope('birnn_encoder'):
            self.bilstm_encoder1 = Bi_RNN(cell_name='GRUCell', hidden_size=conf.hidden_size, dropout_rate=self.dropout_rate)
            encoder_output, _ = self.bilstm_encoder1(encoder_input, s1_seqlen)  # [batch,len,2hid]

        # decoder

        s2_embed, _ = embedding(tf.expand_dims(self.s2, -1), conf.vocab_size, conf.embed_size, name='share_embedding', pretrain_embedding=pretrained_word_embeddings)
        s2_mask = mask_nonpad_from_embedding(s2_embed)  # [batch,len] 1 for nonpad; 0 for pad
        s2_seqlen = tf.cast(tf.reduce_sum(s2_mask, -1), tf.int32)  # [batch]

        decoder_input = s2_embed
        decoder_input = shift_right(decoder_input)  # original note: pad is used as eos
        decoder_input = tf.layers.dropout(decoder_input, rate=self.dropout_rate)  # dropout

        decoder_rnn = tf.nn.rnn_cell.DropoutWrapper(getattr(tf.nn.rnn_cell, 'GRUCell')(conf.hidden_size),  # GRUCell/LSTMCell
                                                    input_keep_prob=1.0 - self.dropout_rate)

        attention_mechanism = getattr(tf.contrib.seq2seq, 'BahdanauAttention')(
            conf.hidden_size,
            encoder_output,
            memory_sequence_length=s1_seqlen,
            name='BahdanauAttention',
        )
        cell = tf.contrib.seq2seq.AttentionWrapper(decoder_rnn,
                                                   attention_mechanism,
                                                   output_attention=False,
                                                   name='attention_wrapper',
                                                   )

        with tf.variable_scope('decoder'):
            decoder_output, _ = tf.nn.dynamic_rnn(
                cell,
                decoder_input,
                s2_seqlen,
                initial_state=None,  # defaults to a zero-vector initial state
                dtype=tf.float32,
                time_major=False
            )  # default scope is rnn, e.g. decoder/rnn/kernel

        logits = proj_logits(decoder_output, conf.embed_size, conf.vocab_size, name='share_embedding')

        onehot_s2 = tf.one_hot(self.s2, depth=conf.vocab_size)  # [batch,len,vocab]

        xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=onehot_s2)  # [batch,len]
        weights = tf.to_float(tf.not_equal(self.s2, 0))  # [batch,len] 1 for nonpad; 0 for pad

        loss_num = xentropy * weights  # [batch,len]
        loss_den = weights  # [batch,len]

        # Average cross-entropy over non-pad target positions.
        loss = tf.reduce_sum(loss_num) / tf.reduce_sum(loss_den)  # scalar
        self.loss = loss

        # decoder infer
        # Tile the encoder memory and its lengths beam_size times: [batch*beam, len, 2hid] / [batch*beam]
        batch_size = shape_list(encoder_output)[0]
        last_dim = shape_list(encoder_output)[-1]
        tile_encoder_output = tf.tile(tf.expand_dims(encoder_output, 1), [1, conf.beam_size, 1, 1])
        tile_encoder_output = tf.reshape(tile_encoder_output, [batch_size * conf.beam_size, -1, last_dim])
        tile_s1_seqlen = tf.tile(tf.expand_dims(s1_seqlen, 1), [1, conf.beam_size])
        tile_s1_seqlen = tf.reshape(tile_s1_seqlen, [-1])

        # Original author's note: tf's BahdanauAttention needs the beam_size-tiled
        # memory at construction time, so a second mechanism dedicated to inference
        # is created here over the tiled memory, intended to share parameters.
        # Verification still showed problems: reuse=True cannot be used, it raises
        # "Variable memory_layer_1/kernel does not exist, or was not created with
        # tf.get_variable()". This path is therefore not recommended.
        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            attention_mechanism_decoder = getattr(tf.contrib.seq2seq, 'BahdanauAttention')(
                conf.hidden_size,
                tile_encoder_output,
                memory_sequence_length=tile_s1_seqlen,
                name='BahdanauAttention',
            )
            cell_decoder = tf.contrib.seq2seq.AttentionWrapper(decoder_rnn,
                                                               attention_mechanism_decoder,
                                                               output_attention=False,
                                                               name='attention_wrapper',
                                                               )

        # Original note: zero_state checks batch_size against the (tiled) memory,
        # hence the multiplication by beam_size.
        initial_state = cell_decoder.zero_state(batch_size * conf.beam_size, tf.float32)

        # Initialize the cache.
        # Original note: cache values get expanded and merged during beam_search,
        # which requires tensors of rank greater than 1; state fields that do not
        # qualify are kept in unable_cache instead.
        cache = {
            'cell_state': initial_state.cell_state,
            'attention': initial_state.attention,
            'alignments': initial_state.alignments,
            'attention_state': initial_state.attention_state,
        }
        unable_cache = {
            'alignment_history': initial_state.alignment_history,
            # 'time': initial_state.time
        }

        # Slice the cache back to batch size first; per the original note,
        # beam_search expand/merge/gather brings the state to batch*beam.
        cache = nest.map_structure(lambda s: s[:batch_size], cache)

        def symbols_to_logits_fn(ids, i, cache):
            nonlocal unable_cache
            ids = ids[:, -1:]  # keep only the last generated id
            target = tf.expand_dims(ids, axis=-1)  # [batch,1,1]
            # Look up with conf.embed_size, matching every other lookup on the
            # shared 'share_embedding' table.
            embedding_target, _ = embedding(target, conf.vocab_size, conf.embed_size, 'share_embedding', reuse=True)
            rnn_input = tf.squeeze(embedding_target, axis=1)  # [batch,embed]

            # merge cache and unable_cache into a full attention-wrapper state
            state = cell_decoder.zero_state(batch_size * conf.beam_size, tf.float32).clone(
                cell_state=cache['cell_state'],
                attention=cache['attention'],
                alignments=cache['alignments'],
                attention_state=cache['attention_state'],
                alignment_history=unable_cache['alignment_history'],
                # time=unable_cache['time'],
                time=tf.convert_to_tensor(i, dtype=tf.int32),
            )

            with tf.variable_scope('decoder/rnn', reuse=tf.AUTO_REUSE):
                output, state = cell_decoder(rnn_input, state)
            # split the new state back into cache and unable_cache
            cache['cell_state'] = state.cell_state
            cache['attention'] = state.attention
            cache['alignments'] = state.alignments
            cache['attention_state'] = state.attention_state
            unable_cache['alignment_history'] = state.alignment_history
            # unable_cache['time'] = state.time
            body_output = output  # [batch,hidden]

            logits = proj_logits(body_output, conf.embed_size, conf.vocab_size, name='share_embedding')
            return logits, cache

        initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # <pad> serves as <sos>

        def greedy_search_wrapper():
            """ Greedy Search """
            decoded_ids, scores = greedy_search(
                symbols_to_logits_fn,
                initial_ids,
                max_decode_len=conf.max_decode_len,
                cache=cache,
                eos_id=conf.eos_id,
            )
            return decoded_ids, scores

        def beam_search_wrapper():
            """ Beam Search """
            decoded_ids, scores = beam_search(  # [batch,beam,len] [batch,beam]
                symbols_to_logits_fn,
                initial_ids,
                conf.beam_size,
                conf.max_decode_len,
                conf.vocab_size,
                states=cache,
                eos_id=conf.eos_id,
                gamma=conf.gamma,
                num_group=conf.num_group,
                top_k=conf.top_k,
            )
            return decoded_ids, scores

        decoded_ids, scores = tf.cond(tf.equal(conf.beam_size, 1), greedy_search_wrapper, beam_search_wrapper)

        self.decoded_ids = tf.identity(decoded_ids, name='decoded_ids')  # [batch,beam/1,len]
        self.scores = tf.identity(scores, name='scores')  # [batch,beam/1]

        self.global_step = tf.train.get_or_create_global_step()
        self.optimizer = tf.train.AdamOptimizer(learning_rate=conf.lr)
        self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)