def _build_decoder_cell(self):
        """Build the decoder cells for plain decoding and for beam search.

        Both variants share the cells stored in ``self.init_decoder_cell_list``;
        only the top cell is wrapped in an ``AttentionWrapper``. The beam
        variant attends over encoder tensors tiled ``self.beam_width`` times
        and builds its Luong mechanism under the same variable scope with
        ``reuse=True``.

        Returns:
            A 4-tuple ``(cell, initial_state, beam_cell, beam_initial_state)``.
            Each cell is a ``MultiRNNCell``; each initial state is a tuple of
            the encoder's last states whose final entry is replaced by the
            zero state of the corresponding attention wrapper.
        """
        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length

        def attn_decoder_input_fn(inputs, attention):
            """Return ``inputs``, or a dense projection of ``[inputs, attention]``."""
            if not self.attn_input_feeding:
                return inputs
            feeding_layer = Dense(
                self.hidden_units, dtype=self.dtype, name="attn_input_feeding")
            return feeding_layer(array_ops.concat([inputs, attention], -1))

        # ---- no beam: 'luong' attention over the raw encoder outputs ----
        with tf.variable_scope('shared_attention_mechanism'):
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)

        # The base cells are created once and used by both decoder variants.
        self.init_decoder_cell_list = [
            self._build_single_cell() for _ in range(self.depth)]

        self.decoder_cell_list = self.init_decoder_cell_list[:-1] + [
            attention_wrapper.AttentionWrapper(
                cell=self.init_decoder_cell_list[-1],
                attention_mechanism=self.attention_mechanism,
                attention_layer_size=self.hidden_units,
                cell_input_fn=attn_decoder_input_fn,
                initial_cell_state=encoder_last_state[-1],
                alignment_history=False)]

        # The wrapped top layer needs the wrapper's own state structure, so its
        # slot is filled from zero_state(); lower layers keep the encoder state.
        initial_state = list(encoder_last_state)
        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=self.batch_size, dtype=self.dtype)
        decoder_initial_state = tuple(initial_state)

        # ---- beam: same construction over beam_width-tiled encoder tensors ----
        beam_encoder_outputs = seq2seq.tile_batch(
            self.encoder_outputs, multiplier=self.beam_width)
        beam_encoder_last_state = nest.map_structure(
            lambda s: seq2seq.tile_batch(s, self.beam_width),
            self.encoder_last_state)
        beam_encoder_inputs_length = seq2seq.tile_batch(
            self.encoder_inputs_length, multiplier=self.beam_width)

        with tf.variable_scope('shared_attention_mechanism', reuse=True):
            self.beam_attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=beam_encoder_outputs,
                memory_sequence_length=beam_encoder_inputs_length)

        self.beam_decoder_cell_list = self.init_decoder_cell_list[:-1] + [
            attention_wrapper.AttentionWrapper(
                cell=self.init_decoder_cell_list[-1],
                attention_mechanism=self.beam_attention_mechanism,
                attention_layer_size=self.hidden_units,
                cell_input_fn=attn_decoder_input_fn,
                initial_cell_state=beam_encoder_last_state[-1],
                alignment_history=False)]

        beam_initial_state = list(beam_encoder_last_state)
        beam_initial_state[-1] = self.beam_decoder_cell_list[-1].zero_state(
            batch_size=self.batch_size * self.beam_width, dtype=self.dtype)
        beam_decoder_initial_state = tuple(beam_initial_state)

        return MultiRNNCell(self.decoder_cell_list), decoder_initial_state, \
               MultiRNNCell(self.beam_decoder_cell_list), beam_decoder_initial_state
  def testLuongScaledDType(self):
    """Scaled Luong attention must keep the requested dtype while decoding.

    Test case for GitHub issue 18099. For each float dtype a decoder is built
    on placeholders, and the structure and dtype of its results are checked.
    """
    batch_size = 64
    num_units = 128
    for dtype in [np.float16, np.float32, np.float64]:
      encoder_outputs = array_ops.placeholder(
          dtype, shape=[batch_size, None, 256])
      encoder_sequence_length = array_ops.placeholder(
          dtypes.int32, shape=[batch_size])
      decoder_inputs = array_ops.placeholder(
          dtype, shape=[batch_size, None, 128])
      decoder_sequence_length = array_ops.placeholder(
          dtypes.int32, shape=[batch_size])
      attention_mechanism = wrapper.LuongAttention(
          num_units=num_units,
          memory=encoder_outputs,
          memory_sequence_length=encoder_sequence_length,
          scale=True,
          dtype=dtype,
      )
      cell = rnn_cell.LSTMCell(num_units)
      cell = wrapper.AttentionWrapper(cell, attention_mechanism)

      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtype, batch_size=batch_size))

      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
      # assertIsInstance reports the offending type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
      self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
      self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
  def testCustomizedAttention(self):
    """A user-supplied ``attention_fn`` drives the wrapper's attention state.

    The custom function returns all-ones tensors, so after running the cell
    the final state's attention, alignments and attention_state must be ones.
    """
    batch_size = 2
    max_time = 3
    num_units = 2
    memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]],
                                   [[4., 4.], [5., 5.], [6., 6.]]])
    memory_sequence_length = constant_op.constant([3, 2])
    attention_mechanism = wrapper.BahdanauAttention(
        num_units, memory, memory_sequence_length)

    def _customized_attention(unused_attention_mechanism, unused_cell_output,
                              unused_attention_state, unused_attention_layer):
      """Ignore every argument and return constant all-ones results.

      Returns:
        A tuple ``(attention, alignments, next_attention_state)`` where
        attention is ones of shape [batch_size, num_units] and both
        alignments and next_attention_state are the same ones tensor of
        shape [batch_size, max_time].
      """
      ones_alignments = array_ops.ones([batch_size, max_time])
      # Keep the creation order of the two ones tensors: attention first.
      ones_attention = array_ops.ones([batch_size, num_units])
      return ones_attention, ones_alignments, ones_alignments

    attention_cell = wrapper.AttentionWrapper(
        rnn_cell.LSTMCell(num_units),
        attention_mechanism,
        attention_layer_size=None,  # no attention layer on top
        output_attention=False,
        alignment_history=(),
        attention_fn=_customized_attention,
        name='attention')
    self.assertEqual(num_units, attention_cell.output_size)

    zero_state = attention_cell.zero_state(
        batch_size=batch_size, dtype=dtypes.float32)
    source_input_emb = array_ops.ones([batch_size, max_time, num_units])
    source_input_length = constant_op.constant([3, 2])

    # The final state exposes `attention`, `alignments` and
    # `attention_state`, which are inspected below.
    output, state = rnn.dynamic_rnn(
        attention_cell,
        inputs=source_input_emb,
        sequence_length=source_input_length,
        initial_state=zero_state,
        dtype=dtypes.float32)

    expected_attention = np.array([[1., 1.], [1., 1.]])
    expected_alignments = np.array([[1., 1., 1.], [1., 1., 1.]])

    with self.session() as sess:
      sess.run(variables.global_variables_initializer())
      output_value, state_value = sess.run([output, state], feed_dict={})
      self.assertAllEqual(
          np.array([batch_size, max_time, num_units]), output_value.shape)
      self.assertAllClose(expected_attention, state_value.attention)
      self.assertAllClose(expected_alignments, state_value.alignments)
      self.assertAllClose(expected_alignments, state_value.attention_state)
Beispiel #4
0
    def build_decoder_cell(self):
        """Build the multi-layer decoder cell and its initial state.

        When ``self.use_beamsearch_decode`` is set, the encoder outputs, last
        state and input lengths are tiled ``self.beam_width`` times and the
        batch size used for the zero state is scaled accordingly. Attention is
        applied to the top decoder layer only.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.decoder_cell_list`` and a tuple of the encoder's last
            states whose final entry is the attention wrapper's zero state.
        """
        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length
        batch_size = self.batch_size

        if self.use_beamsearch_decode:
            print("use beamsearch decoding..")
            encoder_outputs = seq2seq.tile_batch(
                self.encoder_outputs, multiplier=self.beam_width)
            encoder_last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, self.beam_width),
                self.encoder_last_state)
            encoder_inputs_length = seq2seq.tile_batch(
                self.encoder_inputs_length, multiplier=self.beam_width)
            batch_size = self.batch_size * self.beam_width

        # Default: 'Bahdanau' style attention, https://arxiv.org/abs/1409.0473
        # NOTE(review): this mechanism is constructed even when it is replaced
        # by Luong below; left as-is so graph construction is unchanged.
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)
        # 'Luong' style attention: https://arxiv.org/abs/1508.04025
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)

        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.depth)]

        def attn_decoder_input_fn(inputs, attention):
            """Return ``inputs``, or a dense projection of ``[inputs, attention]``."""
            if not self.attn_input_feeding:
                return inputs

            # Essential when use_residual=True
            feeding_layer = Dense(self.hidden_units, dtype=self.dtype,
                                  name='attn_input_feeding')
            return feeding_layer(array_ops.concat([inputs, attention], -1))

        # Attention is implemented on the top decoder layer only.
        self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state[-1],
            alignment_history=False,
            name='Attention_Wrapper')

        # The wrapped top layer needs the wrapper's own state structure, so its
        # slot is filled from zero_state(); lower layers keep the encoder state.
        initial_state = list(encoder_last_state)
        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=batch_size, dtype=self.dtype)
        decoder_initial_state = tuple(initial_state)

        return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
Beispiel #5
0
    def testLuongScaledDType(self, dtype):
        """Scaled Luong attention keeps ``dtype`` (GitHub issue 18099)."""
        memory = self.encoder_outputs.astype(dtype)
        step_inputs = self.decoder_inputs.astype(dtype)

        mechanism = wrapper.LuongAttentionV2(
            units=self.units,
            memory=memory,
            memory_sequence_length=self.encoder_sequence_length,
            scale=True,
            dtype=dtype,
        )
        attention_cell = wrapper.AttentionWrapper(
            keras.layers.LSTMCell(self.units,
                                  recurrent_activation="sigmoid"),
            mechanism)

        my_decoder = basic_decoder.BasicDecoderV2(
            cell=attention_cell, sampler=sampler_py.TrainingSampler())

        start_state = attention_cell.zero_state(
            dtype=dtype, batch_size=self.batch)
        final_outputs, final_state, _ = my_decoder(
            step_inputs,
            initial_state=start_state,
            sequence_length=self.decoder_sequence_length)

        # Check the output structure, its dtype and the state structure.
        self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual(final_outputs.rnn_output.dtype, dtype)
        self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
Beispiel #6
0
    def build_dec_cell(self, hidden_size):
        """Build the stacked decoder cell and the state it starts from.

        Args:
            hidden_size: Unit count passed to the attention mechanism, the
                attention layer, the input-feeding layer and each single cell.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.dec_cell_list`` and a tuple of encoder last states. When
            ``self.p.use_attn`` is set, the top cell is attention-wrapped and
            the final state entry is that wrapper's zero state.
        """
        params = self.p
        memory = self.enc_outputs
        last_state = self.enc_last_state
        memory_lengths = self.enc_inp_len

        if self.use_beam_search:
            # Beam search decodes beam_width hypotheses per example, so every
            # encoder tensor is tiled along the batch dimension.
            self.logger.info("using beam search decoding")
            memory = seq2seq.tile_batch(self.enc_outputs,
                                        multiplier=params.beam_width)
            last_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, params.beam_width),
                self.enc_last_state)
            memory_lengths = seq2seq.tile_batch(self.enc_inp_len,
                                                params.beam_width)

        # Luong when requested, Bahdanau otherwise.
        if params.attention_type.lower() == 'luong':
            mechanism_cls = attention_wrapper.LuongAttention
        else:
            mechanism_cls = attention_wrapper.BahdanauAttention
        self.attention_mechanism = mechanism_cls(
            num_units=hidden_size,
            memory=memory,
            memory_sequence_length=memory_lengths)

        def attn_dec_input_fn(inputs, attention):
            """Return ``inputs``, or a dense projection of ``[inputs, attention]``."""
            if self.p.attn_input_feeding:
                feeding_layer = Dense(hidden_size,
                                      dtype=self.p.dtype,
                                      name='attn_input_feeding')
                return feeding_layer(tf.concat([inputs, attention], -1))
            return inputs

        self.dec_cell_list = [
            self.build_single_cell(hidden_size) for _ in range(params.depth)
        ]

        if params.use_attn:
            # Only the top layer of the stack attends over the encoder.
            self.dec_cell_list[-1] = attention_wrapper.AttentionWrapper(
                cell=self.dec_cell_list[-1],
                attention_mechanism=self.attention_mechanism,
                attention_layer_size=hidden_size,
                cell_input_fn=attn_dec_input_fn,
                initial_cell_state=last_state[-1],
                alignment_history=False,
                name='attention_wrapper')

        if self.use_beam_search:
            batch_size = params.batch_size * params.beam_width
        else:
            batch_size = params.batch_size

        state_per_layer = list(last_state)
        if params.use_attn:
            state_per_layer[-1] = self.dec_cell_list[-1].zero_state(
                batch_size=batch_size, dtype=params.dtype)

        return MultiRNNCell(self.dec_cell_list), tuple(state_per_layer)
    def create_decoder(self, encoded, inputs, speaker_embed, train=True):
        """Assemble the attention decoder and return it as a ``BasicDecoder``.

        Args:
            encoded: Memory the Bahdanau attention mechanism attends over.
            inputs: Mapping read for ``'text_length'``, ``'text'`` and, when
                training, ``'mel'`` and ``'speech_length'``.
            speaker_embed: Currently unused by this method.
            train: Selects a training helper (scheduled or plain) when true,
                otherwise ``ops.InferenceHelper``; also forwarded to
                ``self.pre_net``.

        Returns:
            A ``basic_decoder.BasicDecoder`` built from the attention cell,
            the chosen helper and the cell's zero state.
        """
        config = self.config
        frame_width = config.mel_features * config.r

        attention_mech = wrapper.BahdanauAttention(
            config.attention_units,
            encoded,
            memory_sequence_length=inputs['text_length'])

        # Three residual GRU layers with input and output projections.
        gru_stack = MultiRNNCell(
            [GRUCell(config.decoder_units) for _ in range(3)])
        decoder_cell = OutputProjectionWrapper(
            InputProjectionWrapper(ResidualWrapper(gru_stack),
                                   config.decoder_units),
            frame_width)

        def decoder_frame_input(inputs, attention):
            """Feed the pre-net of the r-th frame, concatenated with attention."""
            rth_frame = tf.slice(
                inputs, [0, (config.r - 1)*config.mel_features], [-1, -1])
            prenet_out = self.pre_net(rth_frame,
                                      dropout=config.audio_dropout_prob,
                                      train=train)
            return tf.concat([prenet_out, attention], -1)

        cell = wrapper.AttentionWrapper(
            decoder_cell,
            attention_mech,
            attention_layer_size=config.attention_units,
            cell_input_fn=decoder_frame_input,
            alignment_history=True,
            output_attention=False)

        if not train:
            decoder_helper = ops.InferenceHelper(
                tf.shape(inputs['text'])[0], frame_width)
        elif config.scheduled_sample:
            print("if train if config.scheduled_sample: %s" % str(
                (inputs['mel'], inputs['speech_length'],
                 config.scheduled_sample)))
            decoder_helper = helper.ScheduledOutputTrainingHelper(
                inputs['mel'], inputs['speech_length'],
                config.scheduled_sample)
        else:
            decoder_helper = helper.TrainingHelper(inputs['mel'],
                                                   inputs['speech_length'])

        # The batch size is taken dynamically from the text input.
        initial_state = cell.zero_state(dtype=tf.float32,
                                        batch_size=tf.shape(inputs['text'])[0])

        # NOTE: seeding initial_state.attention from a dense projection of
        # speaker_embed was sketched here previously but is disabled.

        return basic_decoder.BasicDecoder(cell, decoder_helper, initial_state)
Beispiel #8
0
    def _attentive_bidirectional_cudnn_LSTM(self,
                                            inputs,
                                            input_size,
                                            lengths,
                                            attention_mechanism=False,
                                            num_units=256,
                                            num_layers=1,
                                            dp_input_keep_prob=1.0,
                                            dp_output_keep_prob=1.0):
        """Run a bidirectional RNN over cuDNN LSTM cells, optionally attentive.

        The forward and backward cells are created under the ``'fw'`` and
        ``'bw'`` variable scopes with identical settings. If
        ``attention_mechanism`` is truthy, both cells are wrapped in an
        ``AttentionWrapper`` sharing that mechanism, with
        ``output_attention=False``.

        Returns:
            Whatever ``bidirectional_dynamic_rnn`` returns for these cells.
        """
        def new_cell():
            """Create one cuDNN LSTM cell from this call's settings."""
            return create_cudnn_LSTM_cell(
                num_units=num_units,
                input_size=input_size,
                num_layers=num_layers,
                dp_input_keep_prob=dp_input_keep_prob,
                dp_output_keep_prob=dp_output_keep_prob)

        with tf.variable_scope('fw'):
            forward_cell = new_cell()

        with tf.variable_scope('bw'):
            backward_cell = new_cell()

        if attention_mechanism:
            # Wrap forward first, then backward, both with the same mechanism.
            forward_cell, backward_cell = [
                attention_wrapper.AttentionWrapper(
                    cell=plain_cell,
                    attention_mechanism=attention_mechanism,
                    output_attention=False)
                for plain_cell in (forward_cell, backward_cell)
            ]

        return bidirectional_dynamic_rnn(cell_fw=forward_cell,
                                         cell_bw=backward_cell,
                                         inputs=inputs,
                                         sequence_length=lengths,
                                         dtype=getdtype())
Beispiel #9
0
    def build_decoder(self):
        """Build the attention decoder inside the ``'decoder'`` variable scope.

        Sets ``self.decoder_embeddings``, ``self.decoder_cells``,
        ``self.decoder_initial_state`` and ``self.output_layer``, then calls
        ``self.interfer()`` when ``self.mode`` is ``'train'`` or
        ``self.decode()`` when it is ``'decode'``.

        :return: None
        """
        print('build decoder with attention...')
        with tf.variable_scope('decoder'):
            embedding_shape = [self.vocab_size, self.embedding_size]
            self.decoder_embeddings = tf.Variable(
                tf.random_uniform(embedding_shape))

            def new_lstm_cell():
                """Create one LSTM layer with a seeded uniform initializer."""
                return tf.contrib.rnn.LSTMCell(
                    self.hidden_size,
                    initializer=tf.random_uniform_initializer(-0.1,
                                                              0.1,
                                                              seed=2))

            # 2.1 add attention: Luong attention over the encoder outputs.
            luong_attention = tf.contrib.seq2seq.LuongAttention(
                num_units=self.hidden_size,
                memory=self.encoder_outputs,
                memory_sequence_length=self.source_sequence_length)

            layers = [new_lstm_cell() for _ in range(self.num_layers)]
            # Only the top layer is wrapped with attention.
            top_cell = attention_wrapper.AttentionWrapper(
                cell=layers[-1],
                attention_mechanism=luong_attention,
                attention_layer_size=self.hidden_size)
            layers[-1] = top_cell

            self.decoder_cells = tf.contrib.rnn.MultiRNNCell(layers)

            # Lower layers start from the encoder states; the wrapped top
            # layer starts from the wrapper's zero state.
            per_layer_state = list(self.encoder_states)
            per_layer_state[-1] = top_cell.zero_state(
                batch_size=self.batch_size, dtype=tf.float32)
            self.decoder_initial_state = tuple(per_layer_state)

            # Fully connected layer sized to the vocabulary.
            self.output_layer = Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                   stddev=0.1))
            if self.mode == 'train':
                self.interfer()
            elif self.mode == 'decode':
                self.decode()
    def _attentive_bidirectional_rnn(self,
                                     inputs,
                                     lengths,
                                     num_units,
                                     attention_mechanism=False,
                                     cell_type='gru',
                                     num_layers=1,
                                     dp_input_keep_prob=1.0,
                                     dp_output_keep_prob=1.0):
        """Run a bidirectional dynamic RNN, optionally with attention.

        Forward and backward cells are created by ``create_rnn_cell`` with the
        same settings. If ``attention_mechanism`` is truthy, both cells are
        wrapped in an ``AttentionWrapper`` sharing that mechanism, with
        ``output_attention=False``.

        Returns:
            Whatever ``bidirectional_dynamic_rnn`` returns for these cells.
        """
        cell_settings = dict(cell_type=cell_type,
                             num_units=num_units,
                             num_layers=num_layers,
                             dp_input_keep_prob=dp_input_keep_prob,
                             dp_output_keep_prob=dp_output_keep_prob)

        # Forward cell is created first, then the backward one.
        forward_cell = create_rnn_cell(**cell_settings)
        backward_cell = create_rnn_cell(**cell_settings)

        if attention_mechanism:
            forward_cell, backward_cell = [
                attention_wrapper.AttentionWrapper(
                    cell=plain_cell,
                    attention_mechanism=attention_mechanism,
                    output_attention=False)
                for plain_cell in (forward_cell, backward_cell)
            ]

        return bidirectional_dynamic_rnn(cell_fw=forward_cell,
                                         cell_bw=backward_cell,
                                         inputs=inputs,
                                         sequence_length=lengths,
                                         dtype=getdtype())
Beispiel #11
0
    def build_decoder_cell(self):
        """Build the stacked decoder cell with attention on the top layer.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
            ``self.decoder_cell_list`` and a tuple of the encoder's last
            states whose final entry is the attention wrapper's zero state.
        """
        memory = self.encoder_outputs
        last_state = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length

        # Default mechanism is 'Bahdanau': https://arxiv.org/abs/1409.0473
        # NOTE(review): it is constructed even when replaced by Luong below.
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.hidden_size,
            memory=memory,
            memory_sequence_length=memory_lengths)
        use_luong = self.attention_type.lower() == 'luong'
        if use_luong:
            # 'Luong': https://arxiv.org/abs/1508.04025
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.hidden_size,
                memory=self.encoder_outputs,
                memory_sequence_length=self.encoder_inputs_length)

        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.layer_num)
        ]

        def att_decoder_input_fn(inputs, attention):
            """Return ``inputs``, or a dense projection of ``[inputs, attention]``."""
            if self.use_att_decoding:
                feeding_layer = Dense(self.hidden_size,
                                      dtype=self.dtype,
                                      name='att_input_feeding')
                return feeding_layer(
                    array_ops.concat([inputs, attention], axis=-1))
            return inputs

        # Only the top decoder layer is wrapped with the attention mechanism;
        # it is handed the last entry of the encoder's last state.
        top_cell = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_size,
            cell_input_fn=att_decoder_input_fn,
            initial_cell_state=last_state[-1],
            alignment_history=False,
            name='Attention_Wrapper')
        self.decoder_cell_list[-1] = top_cell

        per_layer_state = list(last_state)
        per_layer_state[-1] = top_cell.zero_state(
            batch_size=self.batch_size, dtype=self.dtype)
        return MultiRNNCell(self.decoder_cell_list), tuple(per_layer_state)
Beispiel #12
0
    def build_decoder_cell(self, num_units, num_layers, keep_prob):
        """Build the decoder cell stack and seed it with the encoder state.

        The two entries of ``self.encoder_fs`` (named fw/bw here) are merged
        layer by layer by concatenating their ``c`` and ``h`` along axis 1.
        Only the last decoder cell is wrapped with Bahdanau attention.

        Args:
            num_units: Passed to ``self.make_rnn_cell`` and to the attention
                mechanism.
            num_layers: Number of decoder cells and of encoder layers merged.
            keep_prob: Passed to ``self.make_rnn_cell``.

        Returns:
            A pair ``(cell, initial_state)``: a ``MultiRNNCell`` and a tuple
            with one state per layer.
        """
        memory = tf.concat(self.encoder_outputs, axis=-1)

        fw_states, bw_states = self.encoder_fs
        encoder_fs = tuple(
            LSTMStateTuple(
                c=tf.concat((fw_states[layer].c, bw_states[layer].c), axis=1),
                h=tf.concat((fw_states[layer].h, bw_states[layer].h), axis=1))
            for layer in range(num_layers))

        # build decoder cell
        cell_stack = [
            self.make_rnn_cell(num_units, keep_prob) for _ in range(num_layers)
        ]

        # Bahdanau attention is applied to the last cell of the stack only.
        self.attention_machenism = attention_wrapper.BahdanauAttention(
            num_units=num_units,
            memory=memory,
            normalize=False,
            memory_sequence_length=self.encoder_length)

        cell_stack[-1] = attention_wrapper.AttentionWrapper(
            cell_stack[-1],
            self.attention_machenism,
            attention_layer_size=None,
            initial_cell_state=None,
            output_attention=False,
            alignment_history=False,
        )
        decoder_cells = tf.nn.rnn_cell.MultiRNNCell(cell_stack)

        # The attention-wrapped layer's state is an AttentionWrapperState, so
        # the encoder state is injected into it with clone(cell_state=...);
        # every other layer takes the encoder state directly.
        zero_states = decoder_cells.zero_state(self.batch, dtype=tf.float32)
        seeded_states = []
        for zero_state, encoder_state in zip(zero_states, encoder_fs):
            if isinstance(zero_state, tf.contrib.seq2seq.AttentionWrapperState):
                seeded_states.append(zero_state.clone(cell_state=encoder_state))
            else:
                seeded_states.append(encoder_state)

        return decoder_cells, tuple(seeded_states)
    def build_decoder_cell(self):
        """Build the stacked decoder cell with Bahdanau attention on top.

        Returns:
            A pair ``(cell, initial_state)``: a ``tf.contrib.rnn.MultiRNNCell``
            over ``self.decoder_cell_list`` and a tuple of the encoder's last
            states whose final entry is the attention wrapper's zero state.
        """
        encoder_outputs = self.encoder_outputs
        encoder_last_state = self.encoder_last_state
        encoder_inputs_length = self.encoder_inputs_length

        # Building Attention Mechanism: Bahdanau
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=self.decoder_hidden_units,
            memory=encoder_outputs,
            memory_sequence_length=encoder_inputs_length)

        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.depth)
        ]

        def attn_decoder_input_fn(inputs, attention):
            """Return ``inputs``, or a dense projection of ``[inputs, attention]``."""
            if not self.attn_input_feeding:
                return inputs

            feeding_layer = Dense(self.decoder_hidden_units,
                                  dtype=tf.float32,
                                  name='attn_input_feeding')
            return feeding_layer(array_ops.concat([inputs, attention], -1))

        # AttentionWrapper wraps the top RNNCell with the attention_mechanism
        self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.decoder_hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=encoder_last_state[-1],
            alignment_history=False,
            name='Attention_Wrapper')

        # The wrapped top layer needs the wrapper's own state structure, so its
        # slot is filled from zero_state(); lower layers keep the encoder state.
        initial_state = list(encoder_last_state)
        initial_state[-1] = self.decoder_cell_list[-1].zero_state(
            batch_size=self.batch_size, dtype=tf.float32)
        decoder_initial_state = tuple(initial_state)

        return tf.contrib.rnn.MultiRNNCell(
            self.decoder_cell_list), decoder_initial_state
    def build_decoder_cell(self):
        """Build the attention-wrapped decoder cell and its initial state.

        In ``'train'`` mode the encoder tensors are used as they are; in any
        other mode they are tiled ``self.para.beam_width`` times and the batch
        size is scaled to match. The whole ``MultiRNNCell`` stack is wrapped
        in a single ``AttentionWrapper``.

        ``output_attention`` is True for Luong and False for Bahdanau, and is
        forwarded to the wrapper. Previously the flag was computed but never
        passed, so Bahdanau mode ran with the wrapper's default.

        Returns:
            A pair ``(cell, initial_state)`` where the initial state is the
            wrapper's zero state cloned with the encoder states as cell state.
        """
        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.para.num_layers)]

        if self.para.mode == 'train':
            encoder_outputs = self.encoder_outputs
            encoder_inputs_len = self.encoder_inputs_len
            encoder_states = self.encoder_states
            batch_size = self.para.batch_size
        else:
            # Non-training modes decode with beam search: tile every encoder
            # tensor beam_width times along the batch dimension.
            encoder_outputs = seq2seq.tile_batch(
                self.encoder_outputs, multiplier=self.para.beam_width)
            encoder_inputs_len = seq2seq.tile_batch(
                self.encoder_inputs_len, multiplier=self.para.beam_width)
            encoder_states = seq2seq.tile_batch(
                self.encoder_states, multiplier=self.para.beam_width)
            batch_size = self.para.batch_size * self.para.beam_width

        if self.para.attention_mode == 'luong':
            # scaled luong: recommended by authors of NMT
            self.attention_mechanism = attention_wrapper.LuongAttention(
                num_units=self.para.num_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_len,
                scale=True)
            output_attention = True
        else:
            self.attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=self.para.num_units,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_len)
            output_attention = False

        cell = tf.contrib.rnn.MultiRNNCell(self.decoder_cell_list)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.para.num_units,
            # Honour the per-mechanism choice made above.
            output_attention=output_attention,
            name='attention')
        decoder_initial_state = cell.zero_state(
            batch_size, self.dtype).clone(cell_state=encoder_states)

        return cell, decoder_initial_state
Beispiel #15
0
  def build_attention_decoder_cell(self):
    """Build the stacked decoder cell with attention on the top layer.

    Returns:
      A pair ``(cell, initial_state)``: a ``MultiRNNCell`` over
      ``self.decoder_cell_list`` and a tuple of the encoder's last states
      whose final entry is the attention wrapper's zero state.
    """
    memory = self.encoder_outputs
    last_state = self.encoder_last_state
    memory_lengths = self.encoder_inputs_length

    # Bahdanau is the default mechanism.
    # NOTE(review): it is constructed even when replaced by Luong below.
    self.attention_mechanism = attention_wrapper.BahdanauAttention(
      num_units=self.hidden_units,
      memory=memory,
      memory_sequence_length=memory_lengths)

    use_luong = self.attention_type.lower() == 'luong'
    if use_luong:
      self.attention_mechanism = attention_wrapper.LuongAttention(
        num_units=self.hidden_units,
        memory=memory,
        memory_sequence_length=memory_lengths)

    # Building decoder_cell
    self.decoder_cell_list = [
      self.build_single_cell() for _ in range(self.num_layers)]

    def attn_decoder_input_fn(inputs, attention):
      """Return ``inputs``, or a dense projection of ``[inputs, attention]``."""
      if self.attn_input_feeding:
        # Essential when use_residual=True
        merged = tf.concat([inputs, attention], axis=-1)
        return tf.layers.dense(merged, self.hidden_units,
                               name='attn_input_feeding')
      return inputs

    # Only the top layer of the stack is wrapped with attention.
    top_cell = attention_wrapper.AttentionWrapper(
      cell=self.decoder_cell_list[-1],
      attention_mechanism=self.attention_mechanism,
      attention_layer_size=self.hidden_units,
      cell_input_fn=attn_decoder_input_fn,
      initial_cell_state=last_state[-1],
      alignment_history=False,
      name='Attention_Wrapper'
      )
    self.decoder_cell_list[-1] = top_cell

    per_layer_state = list(last_state)
    per_layer_state[-1] = top_cell.zero_state(
      batch_size=self.config['batch_size'], dtype=self.dtype)
    return MultiRNNCell(self.decoder_cell_list), tuple(per_layer_state)
Beispiel #16
0
  def testAttentionWrapperStateShapePropgation(self):
    """Cloning a static-batch state with dynamic-batch fields keeps shapes.

    A zero state is built once with a Python-int batch size and once with a
    batch size taken from `array_ops.shape`. After cloning the former with
    the latter's `cell_state` and `attention`, the resulting shapes must
    equal those of the static state.
    """
    batch_size = max_time = num_units = 5

    memory = random_ops.random_uniform(
        [batch_size, max_time, num_units], seed=1)
    attention = wrapper.LuongAttention(num_units, memory)
    attentive_cell = wrapper.AttentionWrapper(
        rnn_cell.LSTMCell(num_units), attention)

    # Zero state whose batch dimension is known statically.
    reference_state = attentive_cell.zero_state(batch_size, dtypes.float32)
    # Zero state whose batch dimension comes from a shape tensor.
    dynamic_batch = array_ops.shape(memory)[0]
    dynamic_state = attentive_cell.zero_state(dynamic_batch, dtypes.float32)

    cloned = reference_state.clone(
        cell_state=dynamic_state.cell_state,
        attention=dynamic_state.attention)

    compared_pairs = (
        (cloned.cell_state.c, reference_state.cell_state.c),
        (cloned.cell_state.h, reference_state.cell_state.h),
        (cloned.attention, reference_state.attention),
    )
    for actual, expected in compared_pairs:
      self.assertEqual(actual.shape, expected.shape)
Beispiel #17
0
  def _testBahdanauNormalizedDType(self, dtype):
    """Decode with normalized Bahdanau attention in `dtype` and check types.

    Args:
      dtype: Type the fixture inputs are cast to (via `astype`) and that is
        passed to the attention mechanism and to `zero_state`.
    """
    memory = self.encoder_outputs.astype(dtype)
    inputs = self.decoder_inputs.astype(dtype)

    mechanism = wrapper.BahdanauAttentionV2(
        units=self.units,
        memory=memory,
        memory_sequence_length=self.encoder_sequence_length,
        normalize=True,
        dtype=dtype)
    attentive_cell = wrapper.AttentionWrapper(
        rnn_cell.LSTMCell(self.units), mechanism)

    decoder_under_test = basic_decoder.BasicDecoderV2(
        cell=attentive_cell, sampler=sampler_py.TrainingSampler())

    zero_state = attentive_cell.zero_state(dtype=dtype, batch_size=self.batch)
    final_outputs, final_state, _ = decoder_under_test(
        inputs,
        initial_state=zero_state,
        sequence_length=self.decoder_sequence_length)

    # Outputs must carry the requested dtype and the expected container types.
    self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
    self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
Beispiel #18
0
    def _testWithMaybeMultiAttention(self,
                                     is_multi,
                                     create_attention_mechanisms,
                                     expected_final_output,
                                     expected_final_state,
                                     attention_mechanism_depths,
                                     alignment_history=False,
                                     expected_final_alignment_history=None,
                                     attention_layer_sizes=None,
                                     attention_layers=None,
                                     create_query_layer=False,
                                     create_memory_layer=True,
                                     create_attention_kwargs=None):
        """Run a BasicDecoderV2 over an AttentionWrapper and check the results.

        Builds one attention mechanism per entry of
        `create_attention_mechanisms`, wraps an LSTM cell with them, decodes
        random inputs with a TrainingSampler, and compares static shapes and
        `get_result_summary` digests of the evaluated results against the
        expected structures.

        Args:
          is_multi: If true, the list of mechanisms (and the list-valued
            `attention_layer_sizes` / `attention_layers`) is handed to the
            wrapper as-is; otherwise only the first element of each is used.
          create_attention_mechanisms: Sequence of callables invoked as
            `creator(units=..., memory=..., memory_sequence_length=...,
            **create_attention_kwargs)`.
          expected_final_output: Structure compared, via
            `assertAllCloseOrEqual`, with the summarized decoder outputs.
          expected_final_state: Structure compared with the summarized final
            state; its `time` field is read when executing eagerly.
          attention_mechanism_depths: Per-mechanism `units` value, also used
            as the width of the optional query/memory Dense layers.
          alignment_history: Forwarded to the wrapper. When true, the stacked
            alignment history is shape-checked and compared as well.
          expected_final_alignment_history: Expected summary of the alignment
            history; only used when `alignment_history` is true.
          attention_layer_sizes: Optional list of attention layer sizes
            forwarded to the wrapper.
          attention_layers: Optional list of attention layers forwarded to
            the wrapper.
          create_query_layer: If true, a ones-initialized, bias-free Dense
            layer is passed to every mechanism as `query_layer`.
          create_memory_layer: If true, a ones-initialized, bias-free Dense
            layer is passed to every mechanism as `memory_layer`.
          create_attention_kwargs: Optional extra keyword arguments for the
            mechanism creators. The dict is copied and never modified.
        """
        # Allow is_multi to be True with a single mechanism to enable test for
        # passing in a single mechanism in a list.
        assert len(create_attention_mechanisms) == 1 or is_multi
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        # Work on a private copy: the loop below stores layers in this dict,
        # which must not leak into the caller's (possibly shared) dict.
        create_attention_kwargs = dict(create_attention_kwargs or {})

        if attention_layer_sizes is not None:
            # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
            attention_depth = sum(
                attention_layer_size or encoder_output_depth
                for attention_layer_size in attention_layer_sizes)
        elif attention_layers is not None:
            # Compute sum of attention_layers output depth.
            attention_depth = sum(
                attention_layer.compute_output_shape(
                    [batch_size, cell_depth +
                     encoder_output_depth]).dims[-1].value
                for attention_layer in attention_layers)
        else:
            attention_depth = encoder_output_depth * len(
                create_attention_mechanisms)

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanisms = []
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths):
            # Create query/memory layers with a deterministic initializer to
            # avoid randomness in the test between graph and eager.
            if create_query_layer:
                create_attention_kwargs["query_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)
            if create_memory_layer:
                create_attention_kwargs["memory_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)

            attention_mechanisms.append(
                creator(units=depth,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        **create_attention_kwargs))

        with self.cached_session(use_gpu=True):
            attention_layer_size = attention_layer_sizes
            attention_layer = attention_layers
            if not is_multi:
                # Single-mechanism mode takes scalars rather than lists.
                if attention_layer_size is not None:
                    attention_layer_size = attention_layer_size[0]
                if attention_layer is not None:
                    attention_layer = attention_layer[0]
            cell = keras.layers.LSTMCell(cell_depth,
                                         recurrent_activation="sigmoid",
                                         kernel_initializer="ones",
                                         recurrent_initializer="ones")
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanisms if is_multi else attention_mechanisms[0],
                attention_layer_size=attention_layer_size,
                alignment_history=alignment_history,
                attention_layer=attention_layer)
            if cell._attention_layers is not None:
                # Give every attention layer without an initializer a seeded
                # one so the results are reproducible.
                for layer in cell._attention_layers:
                    if layer.kernel_initializer is None:
                        layer.kernel_initializer = initializers.glorot_uniform(
                            seed=1337)

            sampler = sampler_py.TrainingSampler()
            my_decoder = basic_decoder.BasicDecoderV2(cell=cell,
                                                      sampler=sampler)
            initial_state = cell.get_initial_state(dtype=dtypes.float32,
                                                   batch_size=batch_size)
            final_outputs, final_state, _ = my_decoder(
                decoder_inputs,
                initial_state=initial_state,
                sequence_length=decoder_sequence_length)

            self.assertIsInstance(final_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

            # The time dimension is only expected to be static when eager.
            expected_time = (expected_final_state.time
                             if context.executing_eagerly() else None)
            self.assertEqual(
                (batch_size, expected_time, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, expected_time),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[0].get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[1].get_shape().as_list()))

            if alignment_history:
                if is_multi:
                    state_alignment_history = []
                    for history_array in final_state.alignment_history:
                        history = history_array.stack()
                        self.assertEqual(
                            (expected_time, batch_size, encoder_max_time),
                            tuple(history.get_shape().as_list()))
                        state_alignment_history.append(history)
                    state_alignment_history = tuple(state_alignment_history)
                else:
                    state_alignment_history = (
                        final_state.alignment_history.stack())
                    self.assertEqual(
                        (expected_time, batch_size, encoder_max_time),
                        tuple(state_alignment_history.get_shape().as_list()))
                nest.assert_same_structure(
                    cell.state_size, cell.zero_state(batch_size,
                                                     dtypes.float32))
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
            else:
                state_alignment_history = ()

            self.evaluate(variables.global_variables_initializer())
            eval_result = self.evaluate({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            final_output_info = nest.map_structure(
                get_result_summary, eval_result["final_outputs"])
            final_state_info = nest.map_structure(get_result_summary,
                                                  eval_result["final_state"])
            print("final_output_info: ", final_output_info)
            print("final_state_info: ", final_state_info)

            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_output, final_output_info)
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_state, final_state_info)
            if alignment_history:
                final_alignment_history_info = nest.map_structure(
                    get_result_summary, eval_result["state_alignment_history"])
                print("final_alignment_history_info: ",
                      final_alignment_history_info)
                nest.map_structure(
                    self.assertAllCloseOrEqual,
                    expected_final_alignment_history,
                    final_alignment_history_info)
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_output,
                           expected_final_state,
                           attention_mechanism_depth=3,
                           alignment_history=False,
                           expected_final_alignment_history=None,
                           attention_layer_size=6,
                           name=''):
        """Decode through an attention-wrapped LSTM and check the results.

        Feeds random, placeholder-backed inputs through a `BasicDecoder`
        driven by a `TrainingHelper`, then checks result types, static
        shapes, and `get_result_summary` digests of the evaluated tensors.

        Args:
          create_attention_mechanism: Callable invoked as
            `create_attention_mechanism(num_units=..., memory=...,
            memory_sequence_length=...)`.
          expected_final_output: Structure compared, via
            `assertAllCloseOrEqual`, with the summarized decoder outputs.
          expected_final_state: Structure compared with the summarized final
            state.
          attention_mechanism_depth: `num_units` handed to the mechanism.
          alignment_history: Forwarded to the wrapper. When true, the stacked
            alignment history is shape-checked and compared as well.
          expected_final_alignment_history: Expected summary of the alignment
            history; only used when `alignment_history` is true.
          attention_layer_size: Forwarded to the wrapper. When None, the
            expected attention depth is the encoder output depth instead.
          name: Label printed ahead of the copy/paste summary output.
        """
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9

        if attention_layer_size is not None:
            attention_depth = attention_layer_size
        else:
            attention_depth = encoder_output_depth

        # Placeholders leave batch and time dimensions statically unknown.
        decoder_inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time,
                            input_depth).astype(np.float32),
            shape=(None, None, input_depth))
        encoder_outputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, encoder_max_time,
                            encoder_output_depth).astype(np.float32),
            shape=(None, None, encoder_output_depth))

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session(use_gpu=True) as sess:
            with vs.variable_scope(
                    'root',
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_layer_size=attention_layer_size,
                    alignment_history=alignment_history)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                final_outputs, final_state, _ = decoder.dynamic_decode(
                    my_decoder)

            # assertIsInstance reports the offending type on failure.
            self.assertIsInstance(final_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
            self.assertIsInstance(final_state.cell_state,
                                  rnn_cell.LSTMStateTuple)

            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            if alignment_history:
                state_alignment_history = final_state.alignment_history.stack()
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
                self.assertEqual(
                    (None, batch_size, None),
                    tuple(state_alignment_history.get_shape().as_list()))
            else:
                state_alignment_history = ()

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                'final_outputs':
                final_outputs,
                'final_state':
                final_state,
                'state_alignment_history':
                state_alignment_history,
            })

            final_output_info = nest.map_structure(
                get_result_summary, sess_results['final_outputs'])
            final_state_info = nest.map_structure(get_result_summary,
                                                  sess_results['final_state'])
            # Printed so new expected values can be pasted into the tests.
            print(name)
            print('Copy/paste:\nexpected_final_output = %s' %
                  str(final_output_info))
            print('expected_final_state = %s' % str(final_state_info))
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_output, final_output_info)
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_state, final_state_info)
            if alignment_history:
                final_alignment_history_info = nest.map_structure(
                    get_result_summary,
                    sess_results['state_alignment_history'])
                print('expected_final_alignment_history = %s' %
                      str(final_alignment_history_info))
                nest.map_structure(
                    self.assertAllCloseOrEqual,
                    expected_final_alignment_history,
                    final_alignment_history_info)
Beispiel #20
0
    def build_decoder_cell(self):
        """Create the multi-layer decoder cell and its matching initial state.

        Attention is applied on the top decoder layer only. When
        `self.use_beamsearch_decode` is set, the encoder outputs, final
        states and input lengths are tiled `self.beam_width` times first.

        Returns:
            A `(cell, initial_state)` pair: a `MultiRNNCell` over
            `self.decoder_cell_list`, and a tuple of the encoder's final
            per-layer states whose last entry is the zero state of the
            attention-wrapped top cell.
        """

        def attn_decoder_input_fn(inputs, attention):
            """Return the cell input, optionally mixed with attention."""
            if self.attn_input_feeding:
                # Essential when use_residual=True
                projection = Dense(self.hidden_units,
                                   dtype=self.dtype,
                                   name='attn_input_feeding')
                return projection(array_ops.concat([inputs, attention], -1))
            return inputs

        def tile_for_beams(tensor):
            """Repeat each batch entry of `tensor` beam_width times."""
            return seq2seq.tile_batch(tensor, self.beam_width)

        memory = self.encoder_outputs
        final_states = self.encoder_last_state
        memory_lengths = self.encoder_inputs_length

        # BeamSearchDecoder needs every encoder-side tensor tiled so that
        # [batch_size, .., ..] -> [batch_size x beam_width, .., ..]
        if self.use_beamsearch_decode:
            print("use beamsearch decoding..")
            memory = seq2seq.tile_batch(self.encoder_outputs,
                                        multiplier=self.beam_width)
            final_states = nest.map_structure(tile_for_beams,
                                              self.encoder_last_state)
            memory_lengths = seq2seq.tile_batch(self.encoder_inputs_length,
                                                multiplier=self.beam_width)

        mechanism_kwargs = dict(
            num_units=self.hidden_units,
            memory=memory,
            memory_sequence_length=memory_lengths,
        )
        # Default is 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
        self.attention_mechanism = attention_wrapper.BahdanauAttention(
            **mechanism_kwargs)
        # 'Luong' style attention: https://arxiv.org/abs/1508.04025
        if self.attention_type.lower() == 'luong':
            self.attention_mechanism = attention_wrapper.LuongAttention(
                **mechanism_kwargs)

        # One plain cell per layer; the top one is wrapped right below.
        self.decoder_cell_list = [
            self.build_single_cell() for _ in range(self.depth)
        ]
        attentive_top_cell = attention_wrapper.AttentionWrapper(
            cell=self.decoder_cell_list[-1],
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=self.hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            initial_cell_state=final_states[-1],
            alignment_history=True,
            name='Attention_Wrapper')
        self.decoder_cell_list[-1] = attentive_top_cell

        # With beam search, the zero state must cover batch_size * beam_width
        # rows instead of the original batch_size.
        state_batch_size = self.batch_size
        if self.use_beamsearch_decode:
            state_batch_size = state_batch_size * self.beam_width

        # The top layer's slot must hold an AttentionWrapperState, obtained
        # here from AttentionWrapper.zero_state.
        per_layer_states = list(final_states)
        per_layer_states[-1] = attentive_top_cell.zero_state(
            batch_size=state_batch_size, dtype=self.dtype)

        return MultiRNNCell(self.decoder_cell_list), tuple(per_layer_states)
    def _build_output_layer_context(self, ques_outputs, ques_sequence_length,
                                    ctx_outputs, ctx_sequence_length):
        """Build the output layer and return its attention alignment history.

        Inside the 'Output' variable scope this:
          1. attends over `ques_outputs` with the learned query `V_ques`
             (tiled over the batch) and reduces the question to one vector;
          2. uses that vector as the initial state of an attention-wrapped
             RNN that attends over `ctx_outputs`;
          3. runs that RNN for two steps on dummy zero inputs.

        Args:
            ques_outputs: Memory for the question attention; it is also
                multiplied by the alignments, so it is presumably shaped
                (batch, time, depth) - TODO confirm against the caller.
            ques_sequence_length: `memory_sequence_length` for the question
                attention.
            ctx_outputs: Memory for the context attention.
            ctx_sequence_length: `memory_sequence_length` for the context
                attention.

        Returns:
            The `alignment_history` field of the final AttentionWrapper state.

        Raises:
            ValueError: If `output_cell_type` is neither 'gru' nor 'lstm'.
        """
        params = self._model_params
        use_luong = self._attention_type_output_layer == 'Luong'

        with tf.variable_scope('Output'):
            # Luong attention uses the query size rather than the configured
            # attention layer size.
            attention_depth_ques = params['output_attention_layer_size_ques']
            if use_luong:
                attention_depth_ques = params['ques_param_size']
            attention_mechanism_ques = self._build_attention(
                num_units=attention_depth_ques,
                memory=ques_outputs,
                memory_sequence_length=ques_sequence_length,
                attention_type=self._attention_type_output_layer)

            # One learned query vector, shared by every example in the batch.
            self.V_ques = tf.get_variable(
                name='V_ques',
                shape=[1, params['ques_param_size']],
                dtype=getdtype())
            alignments_ques = attention_mechanism_ques(
                tf.tile(self.V_ques, [params['batch_size'], 1]),
                previous_alignments=None)
            # Alignment-weighted sum over the question memory.
            expanded_alignments_ques = tf.expand_dims(alignments_ques, 1)
            context = tf.matmul(expanded_alignments_ques, ques_outputs)
            context = tf.squeeze(context, [1])

            attention_depth_ans = params['output_attention_layer_size_ans']
            if use_luong:
                attention_depth_ans = self.embedded_dim
            attention_mechanism_ans = self._build_attention(
                num_units=attention_depth_ans,
                memory=ctx_outputs,
                memory_sequence_length=ctx_sequence_length,
                attention_type=self._attention_type_output_layer)

            output_rnn_cell = create_rnn_cell(
                cell_type=params['output_cell_type'],
                num_units=self.embedded_dim,
                num_layers=params['output_layers'],
                dp_input_keep_prob=params['output_dp_input_keep_prob'],
                dp_output_keep_prob=params['output_dp_output_keep_prob'])

            # The question context seeds the state of every layer.
            cell_type = params['output_cell_type']
            if cell_type == 'gru':

                def make_layer_state():
                    return context
            elif cell_type == 'lstm':

                def make_layer_state():
                    # Zero cell state (c); the context is the hidden state (h).
                    return rnn_cell_impl.LSTMStateTuple(
                        tf.zeros_like(context, dtype=getdtype()), context)
            else:
                # Without this branch initial_state stayed unbound and the
                # wrapper construction below failed with UnboundLocalError.
                raise ValueError(
                    "Unsupported output_cell_type %r; expected 'gru' or "
                    "'lstm'" % (cell_type,))

            if params['output_layers'] == 1:
                initial_state = make_layer_state()
            else:
                initial_state = tuple(
                    make_layer_state()
                    for _ in range(params['output_layers']))

            attentive_output_cell = attention_wrapper.AttentionWrapper(
                cell=output_rnn_cell,
                attention_mechanism=attention_mechanism_ans,
                alignment_history=True,
                # The step input is discarded; the cell is fed the attention.
                cell_input_fn=lambda _, attention: attention,
                initial_cell_state=initial_state,
                output_attention=False)

            # Two steps on dummy zero inputs; only the state is used.
            _, final_state = tf.nn.static_rnn(
                cell=attentive_output_cell,
                inputs=[
                    tf.zeros([params['batch_size'], 1]),
                    tf.zeros([params['batch_size'], 1])
                ],
                dtype=getdtype())

            return final_state.alignment_history
Beispiel #22
0
def _attention_decoder_wrapper(batch_size, num_units, memory, mutli_layer,
                               dtype=dtypes.float32,
                               attention_layer_size=None, cell_input_fn=None,
                               attention_type='B', probability_fn=None,
                               alignment_history=False, output_attention=True,
                               initial_cell_state=None, normalization=False,
                               sigmoid_noise=0., sigmoid_noise_seed=None,
                               score_bias_init=0.):
    """Wrap an RNN decoder cell with an attention mechanism.

    A walkthrough of the parameters can be found at:
        blog.csdn.net/qsczse943062710/article/details/79539005

    Args:
        batch_size: Batch size used to build the wrapper's zero state.
        num_units: `num_units` of the attention mechanism.
        memory: Memory the attention mechanism attends over. No
            `memory_sequence_length` is passed to the mechanism.
        mutli_layer: The cell to wrap, an object returned by
            `_mutli_layer_rnn()`.
        dtype: dtype of the zero state.
        attention_layer_size: Forwarded to `AttentionWrapper`.
        cell_input_fn: Forwarded to `AttentionWrapper`.
        attention_type: Which mechanism to build:
            'B'  - BahdanauAttention, as described in
                   Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
                   "Neural Machine Translation by Jointly Learning to Align
                   and Translate." ICLR 2015. https://arxiv.org/abs/1409.0473
            'L'  - LuongAttention, as described in
                   Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
                   "Effective Approaches to Attention-based Neural Machine
                   Translation." EMNLP 2015. https://arxiv.org/abs/1508.04025
            'BM' - monotonic attention with a Bahdanau-style energy function.
            'LM' - monotonic attention with a Luong-style energy function.
            Monotonic attention is described in
                   Colin Raffel, Minh-Thang Luong, Peter J. Liu,
                   Ron J. Weiss, Douglas Eck. "Online and Linear-Time
                   Attention by Enforcing Monotonic Alignments." ICML 2017.
                   https://arxiv.org/abs/1704.00784
        probability_fn: Forwarded to the mechanism for 'B' and 'L' only.
        alignment_history: Forwarded to `AttentionWrapper`.
        output_attention: Forwarded to `AttentionWrapper`.
        initial_cell_state: Forwarded to `AttentionWrapper`.
        normalization: Passed as `normalize` to the Bahdanau variants and as
            `scale` to the Luong variants. For the weight normalization idea
            see Tim Salimans, Diederik P. Kingma. "Weight Normalization: A
            Simple Reparameterization to Accelerate Training of Deep Neural
            Networks." https://arxiv.org/abs/1602.07868
        sigmoid_noise: Forwarded to the monotonic mechanisms ('BM', 'LM').
        sigmoid_noise_seed: Forwarded to the monotonic mechanisms.
        score_bias_init: Forwarded to the monotonic mechanisms.

    Returns:
        A pair `(att_wrapper, init_states)`: the `AttentionWrapper` cell and
        its zero state for `batch_size` and `dtype`.

    Raises:
        ValueError: If `attention_type` is not one of 'B', 'BM', 'L', 'LM'.

    Example:
        att_wrapper, states = _attention_decoder_wrapper(*args)
        while decoding:
            output, states = att_wrapper(input, states)
            ...
            some processing on output
            ...
            input = processed_output
    """

    if attention_type == 'B':
        attention_mechanism = att_w.BahdanauAttention(
            num_units=num_units,
            memory=memory,
            probability_fn=probability_fn,
            normalize=normalization)
    elif attention_type == 'BM':
        attention_mechanism = att_w.BahdanauMonotonicAttention(
            num_units=num_units,
            memory=memory,
            normalize=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    elif attention_type == 'L':
        attention_mechanism = att_w.LuongAttention(
            num_units=num_units,
            memory=memory,
            probability_fn=probability_fn,
            scale=normalization)
    elif attention_type == 'LM':
        attention_mechanism = att_w.LuongMonotonicAttention(
            num_units=num_units,
            memory=memory,
            scale=normalization,
            sigmoid_noise=sigmoid_noise,
            sigmoid_noise_seed=sigmoid_noise_seed,
            score_bias_init=score_bias_init)
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception that names the offending value.
        raise ValueError(
            "Invalid attention type: %r (expected 'B', 'BM', 'L' or 'LM')"
            % (attention_type,))

    att_wrapper = att_w.AttentionWrapper(
        cell=mutli_layer,
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size,
        cell_input_fn=cell_input_fn,
        alignment_history=alignment_history,
        output_attention=output_attention,
        initial_cell_state=initial_cell_state)
    init_states = att_wrapper.zero_state(batch_size=batch_size, dtype=dtype)
    return att_wrapper, init_states
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_output,
                           expected_final_state,
                           attention_mechanism_depth=3,
                           alignment_history=False,
                           expected_final_alignment_history=None,
                           name=""):
        encoder_sequence_length = [3, 2, 3, 1, 0]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        attention_depth = 6

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session(use_gpu=True) as sess:
            with vs.variable_scope(
                    "root",
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = core_rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_size=attention_depth,
                    alignment_history=alignment_history)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                final_outputs, final_state = decoder.dynamic_decode(my_decoder)

            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.AttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state,
                           core_rnn_cell.LSTMStateTuple))

            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            if alignment_history:
                state_alignment_history = final_state.alignment_history.stack()
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
                self.assertEqual(
                    (None, batch_size, encoder_max_time),
                    tuple(state_alignment_history.get_shape().as_list()))
            else:
                state_alignment_history = ()

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            print("Copy/paste (%s)\nexpected_final_output = " % name,
                  sess_results["final_outputs"])
            sys.stdout.flush()
            print("Copy/paste (%s)\nexpected_final_state = " % name,
                  sess_results["final_state"])
            sys.stdout.flush()
            print(
                "Copy/paste (%s)\nexpected_final_alignment_history = " % name,
                sess_results["state_alignment_history"])
            sys.stdout.flush()
            nest.map_structure(self.assertAllClose, expected_final_output,
                               sess_results["final_outputs"])
            nest.map_structure(self.assertAllClose, expected_final_state,
                               sess_results["final_state"])
            if alignment_history:  # by default, the wrapper emits attention as output
                self.assertAllClose(
                    # outputs are batch major but the stacked TensorArray is time major
                    sess_results["state_alignment_history"],
                    expected_final_alignment_history)
  def _build_decoder(self,
                     encoder_outputs,
                     enc_src_lengths,
                     tgt_inputs = None,
                     tgt_lengths = None,
                     GO_SYMBOL = 1,
                     END_SYMBOL = 2,
                     out_layer_activation = None):
    """
    Builds the attention decoder sub-graph inside the "Decoder" variable scope.

    The decoder is driven by a TrainingHelper when self.mode is "train" or
    "eval", by a GreedyEmbeddingHelper when self.mode is "infer", and by a
    BeamSearchDecoder when self.mode is "infer" and self._decoder_type is
    "beam_search".
    TODO: add param tensor shapes
    :param encoder_outputs: attention memory handed to self._build_attention
    :param enc_src_lengths: memory lengths handed to self._build_attention; twice
      their maximum caps the number of decoding steps outside of training
    :param tgt_inputs: ids looked up in the target embedding ("train"/"eval" only)
    :param tgt_lengths: sequence lengths for the TrainingHelper; their maximum
      caps the number of decoding steps when training
    :param GO_SYMBOL: id used as start token at inference time
    :param END_SYMBOL: id used as end token at inference time
    :param out_layer_activation: activation of the vocabulary projection layer
    :return: (final_outputs, final_state, final_sequence_lengths) exactly as
      returned by tf.contrib.seq2seq.dynamic_decode
    :raises NotImplementedError: if self.mode is not "infer", "train" or "eval"
    """
    # NOTE(review): this method reads both self.mode and self._mode; each use
    # below keeps the attribute it originally read - confirm they always agree.
    params = self.model_params
    with tf.variable_scope("Decoder"):
      tgt_vocab_size = params['tgt_vocab_size']
      tgt_emb_size = params['tgt_emb_size']
      self._tgt_w = tf.get_variable(name='W_tgt_embedding',
                                    shape=[tgt_vocab_size, tgt_emb_size],
                                    dtype=getdtype())
      batch_size = params['batch_size']

      # Dropout is only active while training; otherwise keep everything.
      if self._mode == "train":
        input_keep_prob = params['decoder_dp_input_keep_prob']
      else:
        input_keep_prob = 1.0
      if self._mode == "train":
        output_keep_prob = params['decoder_dp_output_keep_prob']
      else:
        output_keep_prob = 1.0

      decoder_cell = create_rnn_cell(
        cell_type=params['decoder_cell_type'],
        cell_params={"num_units": params['decoder_cell_units']},
        num_layers=params['decoder_layers'],
        dp_input_keep_prob=input_keep_prob,
        dp_output_keep_prob=output_keep_prob,
        residual_connections=params['decoder_use_skip_connections'])

      output_layer = layers_core.Dense(tgt_vocab_size, use_bias=False,
                                       activation = out_layer_activation)

      def attn_decoder_custom_fn(inputs, attention):
        # With skip connections the concatenated [inputs, attention] vector is
        # projected to the cell size so that shapes match for the residual sum.
        if self.model_params['decoder_use_skip_connections']:
          projection = layers_core.Dense(self.model_params['decoder_cell_units'],
                                         dtype=getdtype())
          return projection(tf.concat([inputs, attention], -1))
        return tf.concat([inputs, attention], -1)

      def wrap_with_attention(memory, memory_lengths):
        # Attention mechanism over `memory`, wrapped around the decoder cell.
        mechanism = self._build_attention(memory, memory_lengths)
        return attention_wrapper.AttentionWrapper(
          cell=decoder_cell,
          attention_mechanism=mechanism,
          cell_input_fn=attn_decoder_custom_fn)

      if self.mode == "infer":
        if self._decoder_type == "beam_search":
          if "length_penalty" in self.model_params:
            self._length_penalty_weight = self.model_params["length_penalty"]
          else:
            self._length_penalty_weight = 1.0
          # beam_width of 1 should be same as argmax decoder
          if "beam_width" in self.model_params:
            self._beam_width = self.model_params["beam_width"]
          else:
            self._beam_width = 1

          # Every beam needs its own copy of the attention memory.
          tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
            encoder_outputs, multiplier=self._beam_width)
          tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
            enc_src_lengths, multiplier=self._beam_width)
          attentive_decoder_cell = wrap_with_attention(tiled_enc_outputs,
                                                       tiled_enc_src_lengths)

          batch_size_tensor = tf.constant(batch_size)
          beam_start_tokens = tf.tile([GO_SYMBOL], [batch_size])
          beam_initial_state = attentive_decoder_cell.zero_state(
            dtype=getdtype(),
            batch_size=batch_size_tensor * self._beam_width)
          decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=attentive_decoder_cell,
            embedding=self._tgt_w,
            start_tokens=beam_start_tokens,
            end_token=END_SYMBOL,
            initial_state=beam_initial_state,
            beam_width=self._beam_width,
            output_layer=output_layer,
            length_penalty_weight=self._length_penalty_weight)
        else:
          attentive_decoder_cell = wrap_with_attention(encoder_outputs,
                                                       enc_src_lengths)
          helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=self._tgt_w,
            start_tokens=tf.fill([batch_size], GO_SYMBOL),
            end_token=END_SYMBOL)
          greedy_initial_state = attentive_decoder_cell.zero_state(
            batch_size=batch_size, dtype=getdtype())
          decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=attentive_decoder_cell,
            helper=helper,
            initial_state=greedy_initial_state,
            output_layer=output_layer)
      elif self.mode == "train" or self.mode == "eval":
        attentive_decoder_cell = wrap_with_attention(encoder_outputs,
                                                     enc_src_lengths)
        input_vectors = tf.nn.embedding_lookup(self._tgt_w, tgt_inputs)
        helper = tf.contrib.seq2seq.TrainingHelper(
          inputs = input_vectors,
          sequence_length = tgt_lengths)
        teacher_initial_state = attentive_decoder_cell.zero_state(
          batch_size, dtype=getdtype())
        decoder = tf.contrib.seq2seq.BasicDecoder(
          cell=attentive_decoder_cell,
          helper=helper,
          output_layer=output_layer,
          initial_state=teacher_initial_state)
      else:
        raise NotImplementedError("Unknown mode")

      # impute_finished is switched off only for the beam search decoder type.
      impute_finished = self._decoder_type != "beam_search"
      if self._mode == 'train':
        max_steps = tf.reduce_max(tgt_lengths)
      else:
        max_steps = tf.reduce_max(enc_src_lengths)*2
      if 'use_swap_memory' in self.model_params:
        swap_memory = self.model_params['use_swap_memory']
      else:
        swap_memory = False

      outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
        decoder = decoder,
        impute_finished=impute_finished,
        maximum_iterations=max_steps,
        swap_memory = swap_memory)

      return outputs, state, sequence_lengths
Beispiel #25
0
  def _testWithMaybeMultiAttention(self,
                                   is_multi,
                                   create_attention_mechanisms,
                                   expected_final_output,
                                   expected_final_state,
                                   attention_mechanism_depths,
                                   alignment_history=False,
                                   expected_final_alignment_history=None,
                                   attention_layer_sizes=None,
                                   name=''):
    """Decodes with an AttentionWrapper and compares against expected summaries.

    Builds an LSTMCell wrapped in an AttentionWrapper (with one mechanism or a
    list of mechanisms), runs it through a BasicDecoder with a TrainingHelper,
    checks static shapes and then compares `get_result_summary` of the fetched
    values with the expected structures.

    Args:
      is_multi: If True the mechanisms (and layer sizes) are passed to the
        wrapper as lists; otherwise only the first element is passed.
      create_attention_mechanisms: Sequence of callables invoked with
        `num_units`, `memory` and `memory_sequence_length`.
      expected_final_output: Structure compared with the summarized outputs.
      expected_final_state: Structure compared with the summarized final state
        (after the alignment history has been stripped).
      attention_mechanism_depths: `num_units` for each mechanism, zipped with
        `create_attention_mechanisms`.
      alignment_history: Forwarded to the wrapper; when True the stacked
        history is fetched and checked as well.
      expected_final_alignment_history: Structure compared with the summarized
        alignment history; only used when `alignment_history` is True.
      attention_layer_sizes: Optional list of attention layer sizes. None means
        the wrapper gets `attention_layer_size=None`, in which case the
        expected attention depth is the encoder output depth per mechanism.
      name: Label printed ahead of the copy/paste values.
    """
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    if attention_layer_sizes is None:
      attention_depth = encoder_output_depth * len(create_attention_mechanisms)
    else:
      # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
      attention_depth = sum(attention_layer_size or encoder_output_depth
                            for attention_layer_size in attention_layer_sizes)

    # The single-mechanism wrapper takes a scalar size. Only index into the
    # list when there is one: indexing None would raise TypeError even though
    # None is the documented default handled just above.
    if is_multi or attention_layer_sizes is None:
      attention_layer_size = attention_layer_sizes
    else:
      attention_layer_size = attention_layer_sizes[0]

    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanisms = [
        creator(num_units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths)]

    with self.test_session(use_gpu=True) as sess:
      with vs.variable_scope(
          'root',
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        cell = rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
      self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)

      # Static shapes: the time dimension is unknown at graph construction.
      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        if is_multi:
          # One history TensorArray per attention mechanism.
          state_alignment_history = []
          for history_array in final_state.alignment_history:
            history = history_array.stack()
            self.assertEqual(
                (None, batch_size, None),
                tuple(history.get_shape().as_list()))
            state_alignment_history.append(history)
          state_alignment_history = tuple(state_alignment_history)
        else:
          state_alignment_history = final_state.alignment_history.stack()
          self.assertEqual(
              (None, batch_size, None),
              tuple(state_alignment_history.get_shape().as_list()))
        nest.assert_same_structure(
            cell.state_size,
            cell.zero_state(batch_size, dtypes.float32))
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'state_alignment_history': state_alignment_history,
      })

      final_output_info = nest.map_structure(get_result_summary,
                                             sess_results['final_outputs'])
      final_state_info = nest.map_structure(get_result_summary,
                                            sess_results['final_state'])
      # Printed so that refreshed expected values can be pasted into the tests.
      print(name)
      print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
      print('expected_final_state = %s' % str(final_state_info))
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                         final_output_info)
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                         final_state_info)
      if alignment_history:
        final_alignment_history_info = nest.map_structure(
            get_result_summary, sess_results['state_alignment_history'])
        print('expected_final_alignment_history = %s' %
              str(final_alignment_history_info))
        nest.map_structure(
            self.assertAllCloseOrEqual,
            # outputs are batch major but the stacked TensorArray is time major
            expected_final_alignment_history,
            final_alignment_history_info)
Beispiel #26
0
    def _build_decoder(self, encoder_outputs, encoder_state):
        """Build the attention decoder for training or inference.

        In TRAIN mode the decoder is teacher-forced with a TrainingHelper over
        ``self.iterator.target_in``. Otherwise it decodes greedily when
        ``self.beam_width == 1`` and with a BeamSearchDecoder for any other
        beam width, for at most twice the longest source length.

        Args:
            encoder_outputs: Encoder outputs, used as the attention memory.
            encoder_state: Final encoder state. It is tiled for beam search
                but not otherwise consumed here (see the NOTE further down).

        Returns:
            A pair ``(output_ids, outputs)``. In TRAIN mode this is
            ``(rnn_output, sample_id)``; for greedy inference it is
            ``(sample_id, rnn_output)``; for beam search it is
            ``(predicted_ids, beam scores)``.
        """
        is_training = self.mode == tf.contrib.learn.ModeKeys.TRAIN
        with tf.name_scope("seq_decoder"):
            batch_size = self.batch_size
            # The attention memory is encoder_outputs, so it has to be masked
            # with the *source* lengths in every mode; target lengths belong
            # to the TrainingHelper only.
            source_length = self.iterator.source_length
            if not is_training and self.beam_width > 1:
                # Beam search needs one copy of the encoder tensors per beam.
                batch_size = batch_size * self.beam_width
                encoder_outputs = beam_search_decoder.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                encoder_state = nest.map_structure(
                    lambda s: beam_search_decoder.tile_batch(
                        s, self.beam_width), encoder_state)
                source_length = beam_search_decoder.tile_batch(
                    source_length, multiplier=self.beam_width)

            # Build a fresh cell per layer: putting one cell object into
            # MultiRNNCell several times makes the layers share (or clash on)
            # the same variables instead of forming a stack.
            decoder_cell = MultiRNNCell([
                single_rnn_cell(self.hparams.unit_type, self.num_units,
                                self.dropout)
                for _ in range(self.num_layers_decoder)
            ])
            decoder_cell = InputProjectionWrapper(decoder_cell,
                                                  num_proj=self.num_units)
            attention_mechanism = create_attention_mechanism(
                self.hparams.attention_mechanism,
                self.num_units,
                memory=encoder_outputs,
                source_sequence_length=source_length)
            decoder_cell = wrapper.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.num_units,
                output_attention=True,
                alignment_history=False)

            # Original comment (translated): "Set the cell_state of the
            # AttentionWrapperState to the encoder's state."
            # NOTE(review): the code does not do that - decoding starts from a
            # pure zero state and encoder_state is unused. Confirm whether
            # `.clone(cell_state=encoder_state)` was intended before changing.
            initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            embeddings_decoder = tf.get_variable(
                "embedding_decoder",
                [self.num_decoder_symbols, self.num_units],
                initializer=self.initializer,
                dtype=tf.float32)
            output_layer = Dense(units=self.num_decoder_symbols,
                                 use_bias=True,
                                 name="output_layer")

            if is_training:
                decoder_inputs = tf.nn.embedding_lookup(
                    embeddings_decoder, self.iterator.target_in)
                decoder_helper = helper.TrainingHelper(
                    decoder_inputs,
                    sequence_length=self.iterator.target_length)

                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
                final_outputs, final_state, _ = decoder.dynamic_decode(dec)
                # NOTE(review): rnn_output is returned first here while the
                # inference branch returns sample_id first. Kept as-is because
                # callers may rely on this order - confirm.
                output_ids = final_outputs.rnn_output
                outputs = final_outputs.sample_id
            else:

                def embedding_fn(inputs):
                    return tf.nn.embedding_lookup(embeddings_decoder, inputs)

                # Decode for at most twice the longest source sequence.
                decoding_length_factor = 2.0
                max_encoder_length = tf.reduce_max(self.iterator.source_length)
                maximum_iterations = tf.to_int32(
                    tf.round(
                        tf.to_float(max_encoder_length) *
                        decoding_length_factor))

                tgt_sos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                    tf.int32)
                tgt_eos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                    tf.int32)
                # One start token per example, from the un-tiled batch size.
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                greedy = self.beam_width == 1
                if greedy:
                    decoder_helper = helper.GreedyEmbeddingHelper(
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    dec = basic_decoder.BasicDecoder(decoder_cell,
                                                     decoder_helper,
                                                     initial_state,
                                                     output_layer=output_layer)
                else:
                    dec = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        output_layer=output_layer,
                        beam_width=self.beam_width)
                final_outputs, final_state, _ = decoder.dynamic_decode(
                    dec,
                    maximum_iterations=maximum_iterations)
                # This branch never runs in TRAIN mode, so only the beam
                # width decides which output structure was produced.
                if greedy:
                    output_ids = final_outputs.sample_id
                    outputs = final_outputs.rnn_output
                else:
                    output_ids = final_outputs.predicted_ids
                    outputs = final_outputs.beam_search_decoder_output.scores

            return output_ids, outputs
    def _testDynamicDecodeRNN(self, time_major, has_attention):
        """Smoke-test BeamSearchDecoder when run through dynamic_decode.

        Args:
            time_major: Forwarded as ``output_time_major``; expected shapes
                have their first two dimensions swapped when it is True.
            has_attention: When True the LSTM cell is wrapped in an
                AttentionWrapper over beam-tiled memory and the decoder starts
                from a beam-tiled copy of the plain cell's zero state.
        """
        encoder_sequence_length = np.array([3, 2, 3, 1, 1])
        decoder_sequence_length = np.array([2, 0, 1, 2, 3])
        batch_size = 5
        decoder_max_time = 4
        input_depth = 7
        cell_depth = 9
        attention_depth = 6
        vocab_size = 20
        end_token = vocab_size - 1
        start_token = 0
        embedding_dim = 50
        max_out = max(decoder_sequence_length)
        output_layer = layers_core.Dense(vocab_size,
                                         use_bias=True,
                                         activation=None)
        beam_width = 3

        def _t(shape):
            # Swap the leading two dimensions for time-major outputs.
            return (shape[1], shape[0]) + shape[2:] if time_major else shape

        with self.test_session() as sess:
            batch_size_tensor = constant_op.constant(batch_size)
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            cell = rnn_cell.LSTMCell(cell_depth)
            initial_state = cell.zero_state(batch_size, dtypes.float32)
            if has_attention:
                memory = array_ops.placeholder_with_default(
                    np.random.randn(batch_size, decoder_max_time,
                                    input_depth).astype(np.float32),
                    shape=(None, None, input_depth))
                # Memory, its lengths and the start state are tiled per beam.
                tiled_memory = beam_search_decoder.tile_batch(
                    memory, multiplier=beam_width)
                tiled_memory_length = beam_search_decoder.tile_batch(
                    encoder_sequence_length, multiplier=beam_width)
                attention_mechanism = attention_wrapper.BahdanauAttention(
                    num_units=attention_depth,
                    memory=tiled_memory,
                    memory_sequence_length=tiled_memory_length)
                initial_state = beam_search_decoder.tile_batch(
                    initial_state, multiplier=beam_width)
                cell = attention_wrapper.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=attention_depth,
                    alignment_history=False)

            decoder_start_state = cell.zero_state(
                dtype=dtypes.float32,
                batch_size=batch_size_tensor * beam_width)
            if has_attention:
                decoder_start_state = decoder_start_state.clone(
                    cell_state=initial_state)

            beam_decoder = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                embedding=embedding,
                start_tokens=array_ops.fill([batch_size_tensor], start_token),
                end_token=end_token,
                initial_state=decoder_start_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0)

            final_outputs, final_state, final_sequence_lengths = (
                decoder.dynamic_decode(beam_decoder,
                                       output_time_major=time_major,
                                       maximum_iterations=max_out))

            self.assertIsInstance(
                final_outputs,
                beam_search_decoder.FinalBeamSearchDecoderOutput)
            self.assertIsInstance(
                final_state, beam_search_decoder.BeamSearchDecoderState)

            # Static shapes: the time dimension is unknown before running.
            static_shape = _t((batch_size, None, beam_width))
            step_output = final_outputs.beam_search_decoder_output
            self.assertEqual(
                static_shape,
                tuple(step_output.scores.get_shape().as_list()))
            self.assertEqual(
                static_shape,
                tuple(final_outputs.predicted_ids.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            fetched = sess.run({
                'final_outputs': final_outputs,
                'final_state': final_state,
                'final_sequence_lengths': final_sequence_lengths,
            })

            max_sequence_length = np.max(fetched['final_sequence_lengths'])

            # A smoke test
            run_shape = _t((batch_size, max_sequence_length, beam_width))
            run_output = fetched['final_outputs'].beam_search_decoder_output
            self.assertEqual(run_shape, run_output.scores.shape)
            self.assertEqual(run_shape, run_output.predicted_ids.shape)
Beispiel #28
0
from tensorflow.python.ops import rnn_cell
#
# tf.enable_eager_execution()
# Toy dimensions shared by the stand-in encoder memory and the decoder below.
batch_size = 5
# One entry per batch element; passed to the attention mechanism as
# memory_sequence_length.
src_len = [4, 5, 3, 5, 6]
max_times = 6
num_units = 16
# Random stand-in for encoder outputs, shaped
# (batch_size, max_times, num_units).
enc_output = tf.random.normal((batch_size, max_times, num_units),
                              dtype=tf.float32)
#
# Attention RNN cell: an LSTM cell wrapped with Bahdanau attention over
# enc_output. alignment_history=True asks the wrapper to keep the alignments.
rnncell = rnn_cell.LSTMCell(num_units=16)
attention_mechanism = attention_wrapper.BahdanauAttention(
    num_units=num_units, memory=enc_output, memory_sequence_length=src_len)
attnRNNCell = attention_wrapper.AttentionWrapper(
    cell=rnncell,
    attention_mechanism=attention_mechanism,
    alignment_history=True)

# Training inputs: random decoder inputs shaped
# (batch_size, tgt_max_times, num_units) together with per-example lengths.
tgt_len = [5, 6, 2, 7, 4]
tgt_max_times = 7
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units),
                              dtype=tf.float32)
training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len)

# Training decoder: attention cell driven by the training helper, starting
# from the wrapper's zero state.
train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))