예제 #1
0
    def _build_decoder_cell(self):
        """Build attention decoder cells for plain and beam-search decoding.

        Two layer stacks are assembled from the same list of base cells
        (``self.init_decoder_cell_list``). In each stack only the last cell is
        wrapped in an ``AttentionWrapper`` with Luong attention: the plain
        stack attends over the encoder outputs as they are, the beam stack
        over copies tiled ``self.beam_width`` times. Both attention mechanisms
        are created under the ``'shared_attention_mechanism'`` variable scope,
        the second one with ``reuse=True``.

        Side effects: sets ``self.attention_mechanism``,
        ``self.init_decoder_cell_list``, ``self.decoder_cell_list``,
        ``self.beam_attention_mechanism`` and ``self.beam_decoder_cell_list``.

        Returns:
            A 4-tuple ``(cell, initial_state, beam_cell, beam_initial_state)``.
            Both cells are ``MultiRNNCell`` instances. Each state is a tuple
            built from the (possibly tiled) encoder last state whose final
            entry is replaced by the attention wrapper's ``zero_state``.
        """
        def attn_decoder_input_fn(inputs, attention):
            # Input feeding disabled: hand the decoder input through unchanged.
            if not self.attn_input_feeding:
                return inputs
            # Otherwise project the concatenation [inputs; attention] to hidden_units.
            feed_layer = Dense(self.hidden_units, dtype=self.dtype, name="attn_input_feeding")
            return feed_layer(array_ops.concat([inputs, attention], -1))

        # ---- plain (no beam) decoding ----
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length

        # Luong attention over the encoder outputs.
        with tf.variable_scope('shared_attention_mechanism'):
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=memory,
                memory_sequence_length=memory_lengths)

        # Base cells used by both the plain and the beam stack.
        self.init_decoder_cell_list = [self._build_single_cell() for _ in range(self.depth)]

        top_cell = attention_wrapper.AttentionWrapper(
            cell=self.init_decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=last_state[-1],
            alignment_history=False)
        self.decoder_cell_list = self.init_decoder_cell_list[:-1] + [top_cell]

        # The last entry of the state must come from the wrapper's zero_state.
        layer_states = list(last_state)
        layer_states[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=self.batch_size, dtype=self.dtype)
        decoder_initial_state = tuple(layer_states)

        # ---- beam-search decoding: tile every encoder tensor beam_width times ----
        beam_memory = seq2seq.tile_batch(self.encoder_outputs, multiplier=self.beam_width)
        beam_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width), self.encoder_last_state)
        beam_memory_lengths = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

        with tf.variable_scope('shared_attention_mechanism', reuse=True):
            self.beam_attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=beam_memory,
                memory_sequence_length=beam_memory_lengths)

        beam_top_cell = attention_wrapper.AttentionWrapper(
            cell=self.init_decoder_cell_list[-1],
            attention_mechanism=self.beam_attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=beam_last_state[-1],
            alignment_history=False)
        self.beam_decoder_cell_list = self.init_decoder_cell_list[:-1] + [beam_top_cell]

        # zero_state for the beam stack is sized batch_size * beam_width.
        beam_layer_states = list(beam_last_state)
        beam_layer_states[-1] = self.beam_decoder_cell_list[-1].zero_state(
            batch_size=self.batch_size * self.beam_width, dtype=self.dtype)
        beam_decoder_initial_state = tuple(beam_layer_states)

        plain_cell = MultiRNNCell(self.decoder_cell_list)
        beam_cell = MultiRNNCell(self.beam_decoder_cell_list)
        return plain_cell, decoder_initial_state, beam_cell, beam_decoder_initial_state
예제 #2
0
    def build_decoder_cell(self):
        """Build the multi-layer decoder cell with attention on its last layer.

        When ``self.use_beamsearch_decode`` is set, the encoder outputs, last
        state and input lengths are tiled ``self.beam_width`` times before
        being used. The attention mechanism is Bahdanau unless
        ``self.attention_type`` (case-insensitive) equals ``'luong'``.

        Side effects: sets ``self.attention_mechanism`` and
        ``self.decoder_cell_list``.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.decoder_cell_list`` and a tuple built from the encoder last
            state whose final entry is the attention wrapper's ``zero_state``.
        """
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length
        beam_search = self.use_beamsearch_decode

        if beam_search:
            print("use beamsearch decoding..")
            memory = seq2seq.tile_batch(self.encoder_outputs, multiplier=self.beam_width)
            last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_width),
                self.encoder_last_state)
            memory_lengths = seq2seq.tile_batch(
                self.encoder_inputs_length, multiplier=self.beam_width)

        # Default mechanism, 'Bahdanau' style: https://arxiv.org/abs/1409.0473
        # NOTE(review): this object is constructed even when it is replaced
        # right below; restructuring into if/else presumably changes the
        # auto-generated layer/variable names -- confirm before changing.
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.hidden_units,
            memory=memory,
            memory_sequence_length=memory_lengths)
        # 'Luong' style: https://arxiv.org/abs/1508.04025
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=memory,
                memory_sequence_length=memory_lengths)

        # One base cell per decoder layer.
        self.decoder_cell_list = [self.build_single_cell() for _ in range(self.depth)]

        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs
            # Projection back to hidden_units; the original author marks this
            # as essential when use_residual=True.
            feed_layer = Dense(self.hidden_units, dtype=self.dtype,
                               name='attn_input_feeding')
            return feed_layer(array_ops.concat([inputs, attention], -1))

        # Attention is applied to the last decoder layer only.
        top_index = len(self.decoder_cell_list) - 1
        self.decoder_cell_list[top_index] = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[top_index],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=last_state[-1],
            alignment_history=False,
            name='Attention_Wrapper')

        # With beam search the wrapper state covers batch_size * beam_width rows.
        if beam_search:
            state_batch_size = self.batch_size * self.beam_width
        else:
            state_batch_size = self.batch_size

        layer_states = list(last_state)
        layer_states[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=state_batch_size, dtype=self.dtype)

        return MultiRNNCell(self.decoder_cell_list), tuple(layer_states)
  def testLuongScaledDType(self):
    """Scaled Luong attention keeps the requested dtype through decoding.

    Regression test for GitHub issue 18099: for float16, float32 and float64
    a decoder is built around an attention-wrapped LSTM cell and the output
    dtype and the state types of ``dynamic_decode`` are checked.
    """
    batch = 64
    units = 128
    for float_type in (np.float16, np.float32, np.float64):
      memory = array_ops.placeholder(float_type, shape=[batch, None, 256])
      memory_lengths = array_ops.placeholder(dtypes.int32, shape=[batch])
      targets = array_ops.placeholder(float_type, shape=[batch, None, units])
      target_lengths = array_ops.placeholder(dtypes.int32, shape=[batch])

      mechanism = wrapper.LuongAttention(
          num_units=units,
          memory=memory,
          memory_sequence_length=memory_lengths,
          scale=True,
          dtype=float_type,
      )
      attn_cell = wrapper.AttentionWrapper(rnn_cell.LSTMCell(units), mechanism)

      training_helper = helper_py.TrainingHelper(targets, target_lengths)
      start_state = attn_cell.zero_state(dtype=float_type, batch_size=batch)
      my_decoder = basic_decoder.BasicDecoder(
          cell=attn_cell,
          helper=training_helper,
          initial_state=start_state)

      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual(final_outputs.rnn_output.dtype, float_type)
      self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
      self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
예제 #4
0
  def _build_attention(self, encoder_outputs, encoder_sequence_length):
    """
    Builds the attention mechanism inside the "Attention" variable scope.

    The mechanism is selected by ``self.model_params['attention_type']``
    ("bahdanau" or "luong") and sized by
    ``self.model_params['attention_layer_size']``. The optional keys
    'bahdanau_normalize' / 'luong_scale' default to False when absent.

    :param encoder_outputs: passed to the mechanism as ``memory``
    :param encoder_sequence_length: passed as ``memory_sequence_length``
    :return: the constructed attention mechanism
    :raises ValueError: if the attention type is neither of the two above
    """
    params = self.model_params
    with tf.variable_scope("Attention"):
      depth = params['attention_layer_size']
      kind = params['attention_type']

      if kind == 'bahdanau':
        normalize = False
        if 'bahdanau_normalize' in params:
          normalize = params['bahdanau_normalize']
        return attention_wrapper.BahdanauAttention(
            num_units=depth,
            memory=encoder_outputs,
            normalize=normalize,
            memory_sequence_length=encoder_sequence_length,
            probability_fn=tf.nn.softmax)

      if kind == 'luong':
        scale = False
        if 'luong_scale' in params:
          scale = params['luong_scale']
        return attention_wrapper.LuongAttention(
            num_units=depth,
            memory=encoder_outputs,
            scale=scale,
            memory_sequence_length=encoder_sequence_length,
            probability_fn=tf.nn.softmax)

      raise ValueError('Unknown Attention Type')
예제 #5
0
    def build_dec_cell(self, hidden_size):
        """Build the multi-layer decoder cell, optionally with attention.

        With ``self.use_beam_search`` the encoder tensors are tiled
        ``self.p.beam_width`` times first. The attention mechanism (Luong when
        ``self.p.attention_type`` lower-cases to ``'luong'``, Bahdanau
        otherwise) is always created, but the last cell is only wrapped in an
        ``AttentionWrapper`` when ``self.p.use_attn`` is true.

        Args:
            hidden_size: unit count used for the attention mechanism, the
                attention layer, the input-feeding projection and passed to
                ``self.build_single_cell``.

        Side effects: sets ``self.attention_mechanism`` and
        ``self.dec_cell_list``.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.dec_cell_list`` and a tuple built from the encoder last
            state, whose final entry is the wrapper's ``zero_state`` when
            attention is in use.
        """
        memory = self.enc_outputs
        last_state = self.enc_last_state
        memory_lengths = self.enc_inp_len
        beam_search = self.use_beam_search

        if beam_search:
            self.logger.info("using beam search decoding")
            beam_width = self.p.beam_width
            memory = seq2seq.tile_batch(self.enc_outputs, multiplier=beam_width)
            last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.p.beam_width),
                self.enc_last_state)
            memory_lengths = seq2seq.tile_batch(self.enc_inp_len, beam_width)

        # Pick the mechanism class, then build it with the shared arguments.
        if self.p.attention_type.lower() == 'luong':
            mechanism_cls = attention_wrapper.LuongAttention
        else:
            mechanism_cls = attention_wrapper.BahdanauAttention
        self.attention_mechanism = mechanism_cls(
            num_units=hidden_size,
            memory=memory,
            memory_sequence_length=memory_lengths)

        def attn_dec_input_fn(inputs, attention):
            if self.p.attn_input_feeding:
                # Project [inputs; attention] to hidden_size before the cell.
                feed_layer = Dense(hidden_size,
                                   dtype=self.p.dtype,
                                   name='attn_input_feeding')
                return feed_layer(tf.concat([inputs, attention], -1))
            return inputs

        self.dec_cell_list = [
            self.build_single_cell(hidden_size) for _ in range(self.p.depth)
        ]

        use_attention = self.p.use_attn
        if use_attention:
            # Only the last layer attends over the encoder outputs.
            self.dec_cell_list[-1] = attention_wrapper.AttentionWrapper(
                cell=self.dec_cell_list[-1],
                attention_mechanism=self.attention_mechanism,
                attention_layer_size=hidden_size,
                cell_input_fn=attn_dec_input_fn,
                initial_cell_state=last_state[-1],
                alignment_history=False,
                name='attention_wrapper')

        if beam_search:
            state_batch_size = self.p.batch_size * self.p.beam_width
        else:
            state_batch_size = self.p.batch_size

        layer_states = list(last_state)
        if use_attention:
            layer_states[-1] = self.dec_cell_list[-1].zero_state(
                batch_size=state_batch_size, dtype=self.p.dtype)

        return MultiRNNCell(self.dec_cell_list), tuple(layer_states)
예제 #6
0
    def build_decoder_cell(self):
        """Build the multi-layer decoder cell with attention on its last layer.

        The attention mechanism is Bahdanau unless ``self.attention_type``
        (case-insensitive) equals ``'luong'``. No beam-search tiling happens
        here; the encoder tensors are used as stored on ``self``.

        Side effects: sets ``self.attention_mechanism`` and
        ``self.decoder_cell_list``.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.decoder_cell_list`` and a tuple built from the encoder last
            state whose final entry is the attention wrapper's ``zero_state``.
        """
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length

        # Default mechanism, 'Bahdanau': https://arxiv.org/abs/1409.0473
        # NOTE(review): constructed even when replaced right below;
        # restructuring presumably changes auto-generated layer/variable
        # names -- confirm before changing.
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.hidden_size,
            memory=memory,
            memory_sequence_length=memory_lengths)
        # 'Luong': https://arxiv.org/abs/1508.04025
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_size,
                memory=self.encoder_outputs,
                memory_sequence_length=self.encoder_inputs_length)

        # One base cell per decoder layer.
        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.layer_num)
        ]

        def att_decoder_input_fn(inputs, attention):
            if not self.use_att_decoding:
                return inputs
            # Project [inputs; attention] to hidden_size before the cell.
            feed_layer = Dense(self.hidden_size,
                               dtype=self.dtype,
                               name='att_input_feeding')
            return feed_layer(array_ops.concat([inputs, attention], axis=-1))

        # Attention is applied to the last decoder layer only; its wrapped
        # cell starts from the last entry of the encoder last state.
        top_cell_state = last_state[-1]
        self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_size,
            cell_input_fn=att_decoder_input_fn,
            initial_cell_state=top_cell_state,
            alignment_history=False,
            name='Attention_Wrapper')

        layer_states = list(last_state)
        layer_states[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=self.batch_size, dtype=self.dtype)

        return MultiRNNCell(self.decoder_cell_list), tuple(layer_states)
    def build_decoder_cell(self):
        """Build an attention-wrapped multi-layer decoder cell.

        In ``'train'`` mode (``self.para.mode``) the encoder outputs, lengths
        and states are used as stored on ``self``; in any other mode they are
        tiled ``self.para.beam_width`` times and the batch size is multiplied
        accordingly. The whole ``MultiRNNCell`` stack is wrapped in a single
        ``AttentionWrapper``.

        ``self.para.attention_mode == 'luong'`` selects scaled Luong attention
        and makes the wrapper emit the attention vector
        (``output_attention=True``); any other value selects Bahdanau
        attention and makes the wrapper emit the cell output
        (``output_attention=False``).

        Side effects: sets ``self.decoder_cell_list`` and
        ``self.attention_mechanism``.

        Returns:
            A pair ``(cell, decoder_initial_state)`` where the state is the
            wrapper's ``zero_state`` with ``cell_state`` replaced by the
            (possibly tiled) encoder states.
        """
        self.decoder_cell_list = \
           [self.build_single_cell() for i in range(self.para.num_layers)]

        if self.para.mode == 'train':
            encoder_outputs = self.encoder_outputs
            encoder_inputs_len = self.encoder_inputs_len
            encoder_states = self.encoder_states
            batch_size = self.para.batch_size
        else:
            # Any non-train mode tiles the encoder tensors for beam search.
            encoder_outputs = seq2seq.tile_batch(
                self.encoder_outputs, multiplier=self.para.beam_width)
            encoder_inputs_len = seq2seq.tile_batch(
                self.encoder_inputs_len, multiplier=self.para.beam_width)
            encoder_states = seq2seq.tile_batch(
                self.encoder_states, multiplier=self.para.beam_width)
            batch_size = self.para.batch_size * self.para.beam_width

        if self.para.attention_mode == 'luong':
            # scaled luong: recommended by authors of NMT
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.para.num_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_len,
                scale=True)
            output_attention = True
        else:
            self.attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=self.para.num_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_len)
            output_attention = False

        cell = tf.contrib.rnn.MultiRNNCell(self.decoder_cell_list)
        # output_attention was previously computed but never forwarded, so the
        # Bahdanau branch silently used the wrapper default (True).
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.para.num_units,
            output_attention=output_attention,
            name='attention')
        decoder_initial_state = cell.zero_state(
            batch_size, self.dtype).clone(cell_state=encoder_states)

        return cell, decoder_initial_state
예제 #8
0
  def build_attention_decoder_cell(self):
    """Build the multi-layer decoder cell with attention on its last layer.

    The attention mechanism is Bahdanau unless ``self.attention_type``
    (case-insensitive) equals ``'luong'``. The batch size for the wrapper
    state is read from ``self.config['batch_size']``.

    Side effects: sets ``self.attention_mechanism`` and
    ``self.decoder_cell_list``.

    Returns:
      A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
      ``self.decoder_cell_list`` and a tuple built from the encoder last
      state whose final entry is the attention wrapper's ``zero_state``.
    """
    memory = self.encoder_outputs
    last_state = self.encoder_last_state
    memory_lengths = self.encoder_inputs_length

    # NOTE(review): the Bahdanau mechanism is constructed even when it is
    # replaced right below; restructuring presumably changes auto-generated
    # layer/variable names -- confirm before changing.
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
      num_units=self.hidden_units,
      memory=memory,
      memory_sequence_length=memory_lengths)

    if self.attention_type.lower() == 'luong':
      self.attention_mechanism = attention_wrapper.LuongAttention(
        num_units=self.hidden_units,
        memory=memory,
        memory_sequence_length=memory_lengths)

    # One base cell per decoder layer.
    self.decoder_cell_list = [self.build_single_cell() for _ in range(self.num_layers)]

    def attn_decoder_input_fn(inputs, attention):
      if not self.attn_input_feeding:
        return inputs
      # Dense projection of [inputs; attention]; the original author marks
      # this as essential when use_residual=True.
      fed = tf.concat([inputs, attention], axis=-1)
      return tf.layers.dense(fed, self.hidden_units, name='attn_input_feeding')

    # Attention is applied to the last decoder layer only.
    self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
      cell=self.decoder_cell_list[-1],
      attention_mechanism=self.attention_mechanism,
      attention_layer_size=self.hidden_units,
      cell_input_fn=attn_decoder_input_fn,
      initial_cell_state=last_state[-1],
      alignment_history=False,
      name='Attention_Wrapper')

    layer_states = list(last_state)
    layer_states[-1] = self.decoder_cell_list[-1].zero_state(
      batch_size=self.config['batch_size'], dtype=self.dtype)

    return MultiRNNCell(self.decoder_cell_list), tuple(layer_states)
예제 #9
0
  def testAttentionWrapperStateShapePropgation(self):
    """clone() keeps static shapes when given dynamically-sized tensors.

    A zero state built with a Python-int batch size is cloned with the
    cell state and attention of a zero state built from a tensor batch size;
    the resulting shapes must equal those of the static state.
    """
    size = 5  # shared by batch size, max time and number of units
    memory = random_ops.random_uniform([size, size, size], seed=1)
    attn_cell = wrapper.AttentionWrapper(
        rnn_cell.LSTMCell(size), wrapper.LuongAttention(size, memory))

    # Zero state with a static batch size.
    static_state = attn_cell.zero_state(size, dtypes.float32)
    # Zero state whose batch size is only known as a tensor.
    dynamic_state = attn_cell.zero_state(
        array_ops.shape(memory)[0], dtypes.float32)

    cloned = static_state.clone(
        cell_state=dynamic_state.cell_state,
        attention=dynamic_state.attention)

    expected = static_state.cell_state
    self.assertEqual(cloned.cell_state.c.shape, expected.c.shape)
    self.assertEqual(cloned.cell_state.h.shape, expected.h.shape)
    self.assertEqual(cloned.attention.shape, static_state.attention.shape)
예제 #10
0
    def build_decoder_cell(self):
        """Build the multi-layer decoder cell with attention on its last layer.

        When ``self.use_beamsearch_decode`` is set, the encoder outputs, last
        state and input lengths are tiled ``self.beam_width`` times before
        being used. The attention mechanism is Bahdanau unless
        ``self.attention_type`` (case-insensitive) equals ``'luong'``. The
        wrapper is created with ``alignment_history=True``.

        Side effects: sets ``self.attention_mechanism`` and
        ``self.decoder_cell_list``.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.decoder_cell_list`` and a tuple built from the encoder last
            state whose final entry is the attention wrapper's ``zero_state``.
        """
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length
        beam_search = self.use_beamsearch_decode

        # BeamSearchDecoder needs every encoder tensor tiled so that
        # [batch_size, ...] becomes [batch_size x beam_width, ...].
        if beam_search:
            print("use beamsearch decoding..")
            memory = seq2seq.tile_batch(self.encoder_outputs,
                                        multiplier=self.beam_width)
            last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_width),
                self.encoder_last_state)
            memory_lengths = seq2seq.tile_batch(
                self.encoder_inputs_length, multiplier=self.beam_width)

        # Default mechanism, 'Bahdanau' style: https://arxiv.org/abs/1409.0473
        # NOTE(review): constructed even when replaced right below;
        # restructuring presumably changes auto-generated layer/variable
        # names -- confirm before changing.
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.hidden_units,
            memory=memory,
            memory_sequence_length=memory_lengths)
        # 'Luong' style: https://arxiv.org/abs/1508.04025
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=memory,
                memory_sequence_length=memory_lengths)

        # One base cell per decoder layer.
        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.depth)
        ]

        def attn_decoder_input_fn(inputs, attention):
            if not self.attn_input_feeding:
                return inputs
            # Projection back to hidden_units; the original author marks this
            # as essential when use_residual=True.
            feed_layer = Dense(self.hidden_units,
                               dtype=self.dtype,
                               name='attn_input_feeding')
            return feed_layer(array_ops.concat([inputs, attention], -1))

        # Attention is applied to the last decoder layer only.
        self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=last_state[-1],
            alignment_history=True,
            name='Attention_Wrapper')

        # The last layer's state has to be in the wrapper's own state form,
        # which is obtained by calling zero_state on the wrapper. With beam
        # search the batch size passed to zero_state is beam_width times the
        # original batch size.
        if beam_search:
            state_batch_size = self.batch_size * self.beam_width
        else:
            state_batch_size = self.batch_size

        layer_states = list(last_state)
        layer_states[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=state_batch_size, dtype=self.dtype)

        return MultiRNNCell(self.decoder_cell_list), tuple(layer_states)
예제 #11
0
def _attention_decoder_wrapper(batch_size, num_units, memory, mutli_layer, dtype=dtypes.float32,
                               attention_layer_size=None, cell_input_fn=None, attention_type='B',
                               probability_fn=None, alignment_history=False, output_attention=True,
                               initial_cell_state=None, normalization=False, sigmoid_noise=0.,
                               sigmoid_noise_seed=None, score_bias_init=0.):
    """
    Wrap an RNN decoder cell with an attention mechanism.

    A detailed explanation of the parameters can be found at:
        blog.csdn.net/qsczse943062710/article/details/79539005

    :param batch_size: batch size passed to ``zero_state`` of the wrapper
    :param num_units: ``num_units`` of the attention mechanism
    :param memory: the memory the attention mechanism attends over
    :param mutli_layer: an object returned by function _mutli_layer_rnn();
        used as the wrapped ``cell``
    :param dtype: dtype passed to ``zero_state`` of the wrapper
    :param attention_layer_size, cell_input_fn, alignment_history,
        output_attention, initial_cell_state: forwarded unchanged to
        ``AttentionWrapper``
    :param attention_type: string
        'B' is for BahdanauAttention as described in:

          Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
          "Neural Machine Translation by Jointly Learning to Align and Translate."
          ICLR 2015. https://arxiv.org/abs/1409.0473

        'L' is for LuongAttention as described in:

            Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
            "Effective Approaches to Attention-based Neural Machine Translation."
            EMNLP 2015.  https://arxiv.org/abs/1508.04025

        MonotonicAttention is described in:

            Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
            "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
            ICML 2017.  https://arxiv.org/abs/1704.00784

        'BM' :  Monotonic attention mechanism with Bahdanau-style energy function

        'LM' :  Monotonic attention mechanism with Luong-style energy function

    :param probability_fn: forwarded to the mechanism for 'B' and 'L' only;
        it is not passed to the monotonic mechanisms ('BM', 'LM')
    :param normalization: passed as ``normalize`` to the Bahdanau-style
        mechanisms ('B', 'BM') and as ``scale`` to the Luong-style
        mechanisms ('L', 'LM'). The original author notes that the
        normalization is the one described in:
            Tim Salimans, Diederik P. Kingma.
            "Weight Normalization: A Simple Reparameterization to Accelerate
            Training of Deep Neural Networks."
            https://arxiv.org/abs/1602.07868
    :param sigmoid_noise, sigmoid_noise_seed, score_bias_init: forwarded to
        the monotonic mechanisms ('BM', 'LM') only

    :return: ``(att_wrapper, init_states)`` -- the ``AttentionWrapper`` and
        its ``zero_state`` for the given batch size and dtype
    :raises ValueError: if ``attention_type`` is not one of
        'B', 'BM', 'L', 'LM'

    An example usage:
        att_wrapper, states = _attention_decoder_wrapper(*args)
        while decoding:
            output, states = att_wrapper(input, states)
            ...
            some processing on output
            ...
            input = processed_output
    """

    if attention_type == 'B':
        attention_mechanism = att_w.BahdanauAttention(
            num_units=num_units,
            memory=memory,
            probability_fn=probability_fn,
            normalize=normalization)
    elif attention_type == 'BM':
        attention_mechanism = att_w.BahdanauMonotonicAttention(
            num_units=num_units,
            memory=memory,
            normalize=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    elif attention_type == 'L':
        attention_mechanism = att_w.LuongAttention(
            num_units=num_units,
            memory=memory,
            probability_fn=probability_fn,
            scale=normalization)
    elif attention_type == 'LM':
        attention_mechanism = att_w.LuongMonotonicAttention(
            num_units=num_units,
            memory=memory,
            scale=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # proper exception so callers see the intended message.
        raise ValueError('Invalid attention type')

    att_wrapper = att_w.AttentionWrapper(
        cell=mutli_layer,
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size,
        cell_input_fn=cell_input_fn,
        alignment_history=alignment_history,
        output_attention=output_attention,
        initial_cell_state=initial_cell_state)
    init_states = att_wrapper.zero_state(batch_size=batch_size, dtype=dtype)
    return att_wrapper, init_states