Example #1
  def testLuongMonotonicHard(self):
    # Run attention mechanism with mode='hard', make sure probabilities are hard
    b, t, u, d = 10, 20, 30, 40
    with self.test_session(use_gpu=True) as sess:
      a = wrapper.LuongMonotonicAttention(
          d,
          random_ops.random_normal((b, t, u)),
          mode='hard')
      # Just feed previous attention as [1, 0, 0, ...]
      attn, unused_state = a(
          random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
      sess.run(variables.global_variables_initializer())
      attn_out = attn.eval()
      # All values should be 0 or 1
      self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
      # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
      self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
                                           attn_out.sum(axis=1) == 0)))
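
This is a method excerpted from a larger test class, so `wrapper`, `random_ops`, `array_ops`, `variables`, and `np` are defined elsewhere in that file. A minimal standalone scaffold, assuming the TF 1.x contrib module paths (class name and layout are illustrative, not part of the original snippet), could look like this:

# Assumed imports and test scaffold (TF 1.x contrib paths); illustrative only.
import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class LuongMonotonicAttentionTest(test.TestCase):

  # def testLuongMonotonicHard(self): ...  (method body as shown above)
  pass


if __name__ == '__main__':
  test.main()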
Example #2
# Imports assumed from the surrounding module (the snippet references `dtypes`
# and `att_w`; these TF 1.x paths are an educated guess):
from tensorflow.python.framework import dtypes
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as att_w


def _attention_decoder_wrapper(batch_size, num_units, memory, multi_layer,
                               dtype=dtypes.float32, attention_layer_size=None,
                               cell_input_fn=None, attention_type='B',
                               probability_fn=None, alignment_history=False,
                               output_attention=True, initial_cell_state=None,
                               normalization=False, sigmoid_noise=0.,
                               sigmoid_noise_seed=None, score_bias_init=0.):
    """
    A wrapper for rnn-decoder with attention mechanism

    the detail about params explanation can be found at :
        blog.csdn.net/qsczse943062710/article/details/79539005

    :param mutli_layer: a object returned by function _mutli_layer_rnn()

    :param attention_type, string
        'B' is for BahdanauAttention as described in:

          Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
          "Neural Machine Translation by Jointly Learning to Align and Translate."
          ICLR 2015. https://arxiv.org/abs/1409.0473

        'L' is for LuongAttention as described in:

            Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
            "Effective Approaches to Attention-based Neural Machine Translation."
            EMNLP 2015.  https://arxiv.org/abs/1508.04025

        MonotonicAttention is described in :

            Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
            "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
            ICML 2017.  https://arxiv.org/abs/1704.00784

        'BM' :  Monotonic attention mechanism with Bahadanau-style energy function

        'LM' :  Monotonic attention mechanism with Luong-style energy function


        or maybe something user defined in the future
        **warning** :

            if normalization is set True,
            then normalization will be applied to all types of attentions as described in:
                Tim Salimans, Diederik P. Kingma.
                "Weight Normalization: A Simple Reparameterization to Accelerate
                Training of Deep Neural Networks."
                https://arxiv.org/abs/1602.07868

    A example usage:
        att_wrapper, states = _attention_decoder_wrapper(*args)
        while decoding:
            output, states = att_wrapper(input, states)
            ...
            some processing on output
            ...
            input = processed_output
    """

    if attention_type == 'B':
        attention_mechanism = att_w.BahdanauAttention(
            num_units=num_units,
            memory=memory,
            probability_fn=probability_fn,
            normalize=normalization)
    elif attention_type == 'BM':
        attention_mechanism = att_w.BahdanauMonotonicAttention(
            num_units=num_units,
            memory=memory,
            normalize=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    elif attention_type == 'L':
        attention_mechanism = att_w.LuongAttention(
            num_units=num_units,
            memory=memory,
            probability_fn=probability_fn,
            scale=normalization)
    elif attention_type == 'LM':
        attention_mechanism = att_w.LuongMonotonicAttention(
            num_units=num_units,
            memory=memory,
            scale=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    else:
        raise ValueError('Invalid attention type: %s' % attention_type)

    att_wrapper = att_w.AttentionWrapper(
        cell=multi_layer,
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size,
        cell_input_fn=cell_input_fn,
        alignment_history=alignment_history,
        output_attention=output_attention,
        initial_cell_state=initial_cell_state)
    init_states = att_wrapper.zero_state(batch_size=batch_size, dtype=dtype)
    return att_wrapper, init_states
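
For completeness, here is a minimal sketch of how the wrapper might be driven in a decode loop, in the spirit of the docstring's usage example. The `MultiRNNCell` stand-in (used instead of the `_mutli_layer_rnn()` helper), the random memory tensor, the fixed number of steps, and all shapes are illustrative assumptions, not part of the original code:

# Illustrative usage sketch (assumptions: MultiRNNCell stands in for the
# _mutli_layer_rnn() helper; shapes and the fixed-length loop are made up).
import tensorflow as tf

batch_size, num_units, max_time = 32, 128, 50

# Encoder outputs act as the attention memory: [batch, max_time, depth].
memory = tf.random_normal([batch_size, max_time, num_units])

# Stand-in for the multi-layer decoder cell expected by the wrapper.
multi_layer = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(num_units) for _ in range(2)])

att_wrapper, states = _attention_decoder_wrapper(
    batch_size, num_units, memory, multi_layer, attention_type='B')

# Feed each step's (processed) output back in as the next input.
inputs = tf.zeros([batch_size, num_units])
for _ in range(10):  # a fixed number of decode steps, for illustration
    output, states = att_wrapper(inputs, states)
    inputs = output  # "processing" here is just the identity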