Beispiel #1
0
        def func_point_logits(dec_h, enc_e, enc_len):
            '''
            Compute masked pointer logits over the encoder time steps for a
            single decoder state.

            The decoder state is repeated along the encoder time axis,
            concatenated with the encoder features, and passed through two
            bias-free linear layers (scopes 'Ptr_W' and 'Ptr_V') with a tanh
            in between.

            Args:
                dec_h : shape(b_sz, h_dec_sz)
                enc_e : shape(b_sz, tstp_enc, dec_emb_sz)
                enc_len : shape(b_sz,)
            Returns:
                point_logits : shape(b_sz, tstp_enc); entries where the mask
                    built from enc_len is False are replaced by small_num.
            '''
            # tstp_enc, h_dec_sz and small_num are read from the enclosing
            # scope (not visible in this block).

            # Uses the TF >= 1.0 keyword names (axis= / values=) and tf.where
            # in place of the removed dim= / squeeze_dims= / tf.select forms.
            dec_h_ex = tf.expand_dims(dec_h, axis=1)  # shape(b_sz, 1, h_dec_sz)
            dec_h_ex = tf.tile(
                dec_h_ex, [1, tstp_enc, 1])  # shape(b_sz, tstp_enc, h_dec_sz)
            linear_concat = tf.concat(
                axis=2, values=[dec_h_ex, enc_e
                                ])  # shape(b_sz, tstp_enc, h_dec_sz+ dec_emb_sz)
            point_linear = TfUtils.last_dim_linear(  # shape(b_sz, tstp_enc, h_dec_sz)
                linear_concat,
                output_size=h_dec_sz,
                bias=False,
                scope='Ptr_W')
            point_v = TfUtils.last_dim_linear(  # shape(b_sz, tstp_enc, 1)
                tf.tanh(point_linear),
                output_size=1,
                bias=False,
                scope='Ptr_V')
            point_logits = tf.squeeze(
                point_v, axis=[2])  # shape(b_sz, tstp_enc)

            # NOTE(review): mkMask presumably marks positions < enc_len as
            # True -- confirm against TfUtils.
            mask = TfUtils.mkMask(enc_len,
                                  maxLen=tstp_enc)  # shape(b_sz, tstp_enc)
            # Keep logits where mask is True; elsewhere substitute small_num
            # (presumably a large negative value so padded steps vanish after
            # a softmax -- confirm where small_num is defined).
            point_logits = tf.where(mask, point_logits,
                                    tf.ones_like(point_logits) *
                                    small_num)  # shape(b_sz, tstp_enc)

            return point_logits
        def func_point_logits(dec_h, enc_ptr, enc_len):
            '''
            Score every encoder position against every decoder step and mask
            out the positions rejected by the length mask.

            Args:
                dec_h : shape(b_sz, tstp_dec, h_dec_sz)
                enc_ptr : shape(b_sz, tstp_dec, tstp_enc, Ptr_sz)
                enc_len : shape(b_sz,)
            Returns:
                shape(b_sz, tstp_dec, tstp_enc) logits; entries where the mask
                built from enc_len is False hold small_num.
            '''
            # tstp_enc, tstp_dec, h_dec_sz and small_num are read from the
            # enclosing scope (not visible in this block).

            # Repeat each decoder state along a new encoder-time axis.
            # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
            dec_tiled = tf.tile(
                tf.expand_dims(dec_h, axis=2), [1, 1, tstp_enc, 1])
            # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz+ Ptr_sz)
            features = tf.concat(values=[dec_tiled, enc_ptr], axis=3)

            # Two bias-free projections with a tanh in between.
            # shape(b_sz, tstp_dec, tstp_enc, h_dec_sz)
            hidden = TfUtils.last_dim_linear(
                features, output_size=h_dec_sz, bias=False, scope='Ptr_W')
            # shape(b_sz, tstp_dec, tstp_enc, 1)
            scores = TfUtils.last_dim_linear(
                tf.tanh(hidden), output_size=1, bias=False, scope='Ptr_V')
            # shape(b_sz, tstp_dec, tstp_enc)
            logits = tf.squeeze(scores, axis=[3])

            # Give every decoder step its own copy of the encoder lengths.
            # shape(b_sz, tstp_dec)
            step_lens = tf.tile(tf.expand_dims(enc_len, 1), [1, tstp_dec])
            # shape(b_sz, tstp_dec, tstp_enc)
            valid = TfUtils.mkMask(step_lens, maxLen=tstp_enc)

            # Where the mask is False, substitute the constant small_num.
            filler = tf.ones_like(logits) * small_num
            return tf.where(valid, logits, filler)