def testLuongScaledDType(self):
    """Decode with scaled Luong attention for float16, float32 and float64.

    Test case for GitHub issue 18099. For each dtype it builds a
    ``LuongAttention(scale=True, dtype=dtype)`` mechanism over placeholder
    encoder outputs, wraps an ``LSTMCell`` in an ``AttentionWrapper``,
    decodes placeholder inputs with a ``TrainingHelper`` and checks the types
    of the results and that ``rnn_output`` has the requested dtype. The graph
    is only built here, never run.
    """
    batch_size = 64
    num_units = 128
    encoder_depth = 256
    decoder_input_depth = 128
    for dtype in [np.float16, np.float32, np.float64]:
      # All placeholders share batch_size so they agree with the zero state
      # requested from the cell below.
      encoder_outputs = array_ops.placeholder(
          dtype, shape=[batch_size, None, encoder_depth])
      encoder_sequence_length = array_ops.placeholder(
          dtypes.int32, shape=[batch_size])
      decoder_inputs = array_ops.placeholder(
          dtype, shape=[batch_size, None, decoder_input_depth])
      decoder_sequence_length = array_ops.placeholder(
          dtypes.int32, shape=[batch_size])
      attention_mechanism = wrapper.LuongAttention(
          num_units=num_units,
          memory=encoder_outputs,
          memory_sequence_length=encoder_sequence_length,
          scale=True,
          dtype=dtype,
      )
      cell = rnn_cell.LSTMCell(num_units)
      cell = wrapper.AttentionWrapper(cell, attention_mechanism)

      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtype, batch_size=batch_size))

      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)) which only reports "False is not true".
      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
      self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
      self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
  def testLuongScaledDType(self):
    """Decode with scaled Luong attention for float16, float32 and float64.

    Test case for GitHub issue 18099. For each dtype it builds a
    ``LuongAttention(scale=True, dtype=dtype)`` mechanism over placeholder
    encoder outputs, wraps an ``LSTMCell`` in an ``AttentionWrapper``,
    decodes placeholder inputs with a ``TrainingHelper`` and checks the types
    of the results and that ``rnn_output`` has the requested dtype. The graph
    is only built here, never run.
    """
    batch_size = 64
    num_units = 128
    encoder_depth = 256
    decoder_input_depth = 128
    for dtype in [np.float16, np.float32, np.float64]:
      # All placeholders share batch_size so they agree with the zero state
      # requested from the cell below.
      encoder_outputs = array_ops.placeholder(
          dtype, shape=[batch_size, None, encoder_depth])
      encoder_sequence_length = array_ops.placeholder(
          dtypes.int32, shape=[batch_size])
      decoder_inputs = array_ops.placeholder(
          dtype, shape=[batch_size, None, decoder_input_depth])
      decoder_sequence_length = array_ops.placeholder(
          dtypes.int32, shape=[batch_size])
      attention_mechanism = wrapper.LuongAttention(
          num_units=num_units,
          memory=encoder_outputs,
          memory_sequence_length=encoder_sequence_length,
          scale=True,
          dtype=dtype,
      )
      cell = rnn_cell.LSTMCell(num_units)
      cell = wrapper.AttentionWrapper(cell, attention_mechanism)

      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtype, batch_size=batch_size))

      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)) which only reports "False is not true".
      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
      self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
      self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
Ejemplo n.º 3
0
    def build_decoder(self, encoder_outputs, encoder_final_state, hparams):
        """Build the attention decoder's training and greedy-inference graphs.

        A GRU cell wrapped in Bahdanau attention over ``encoder_outputs``
        (masked by ``self.inputs.feature_length``) is shared by two
        BasicDecoders:

        * "train": teacher forcing on one-hot ``self.inputs.target_input``.
          Sets ``self.loss`` (masked softmax cross-entropy) and
          ``self.update`` (an Adam step on gradients clipped to global norm 1).
          It also adds an image summary of the alignment history.
        * "eval": greedy decoding from ``hparams.sos_id`` until
          ``hparams.eos_id`` or at most 50 steps. Sets ``self.predict_id`` and
          ``self.predict_string`` (ids mapped through
          ``self.index_to_string_table``).

        Args:
          encoder_outputs: attention memory.
          encoder_final_state: initial state of the GRU inside the
            AttentionWrapper.
          hparams: reads ``decoder_num_hidden``, ``num_alpha``, ``sos_id`` and
            ``eos_id``.
        """
        decoder_num_hidden = hparams.decoder_num_hidden
        num_alpha = hparams.num_alpha
        with tf.variable_scope("decoder"):
            cell = tf.nn.rnn_cell.GRUCell(decoder_num_hidden)
            # alignment_history=True keeps per-step alignments in the final
            # state; the training branch logs them as an image summary.
            cell = AttentionWrapper(cell, BahdanauAttention(decoder_num_hidden, encoder_outputs,
                                                       memory_sequence_length=self.inputs.feature_length,normalize=True
                                                       ),
                                   initial_cell_state=encoder_final_state, output_attention=True,
                                   attention_layer_size=decoder_num_hidden, alignment_history=True)
            #cell = LocalAttentionWrapper(cell, LocalGaussAttention(decoder_num_hidden, encoder_outputs,
            #                                                       self.inputs.feature_length), decoder_num_hidden,
            #                             initial_cell_state=encoder_final_state)
            # Projects decoder outputs to num_alpha logits; the same layer is
            # passed to both the training and the inference decoder.
            dense_layer = Dense(num_alpha, dtype=tf.float32, use_bias=False)
            with tf.variable_scope("train"):
                # Decoder inputs and targets are one-hot vectors over the
                # alphabet; there is no learned embedding.
                target_input = tf.one_hot(self.inputs.target_input, hparams.num_alpha)
                target_output = tf.one_hot(self.inputs.target_output, hparams.num_alpha)
                helper = TrainingHelper(target_input,
                                        tf.cast(self.inputs.target_length, tf.int32))
                decoder = BasicDecoder(cell, helper, cell.zero_state(self.batch_size, tf.float32), dense_layer)
                # The third return value (_1) is unused.
                decoder_outputs, final_state, _1 = dynamic_decode(decoder, impute_finished=True)
                tf.summary.image('alignment_history', general.create_alignment_image(final_state))
                logits = decoder_outputs.rnn_output
                max_time = get_max_time(target_output, 1)
                # Weight 1 for real target steps and 0 for padding.
                target_weights = tf.sequence_mask(
                    self.inputs.target_length, max_time, dtype=logits.dtype)
                # NOTE(review): reduce_mean divides by every (batch, time)
                # position, padded ones included (they contribute 0), so the
                # loss scale depends on the amount of padding. Dividing the
                # sum by the number of real tokens may be what was intended.
                self.loss = tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits(labels=target_output,
                                                            logits=logits) * target_weights)
                # Clip all trainable-variable gradients to a global norm of 1
                # before the Adam update.
                grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tf.trainable_variables()), 1)
                self.update = tf.train.AdamOptimizer().apply_gradients(zip(grads, tf.trainable_variables()))
            with tf.variable_scope("eval"):
                # Feed the previous greedy prediction back as a one-hot
                # vector, matching how the training inputs are encoded.
                def embedding(ids):
                    vec = tf.one_hot(ids, num_alpha, dtype=tf.float32)
                    return vec

                # NOTE(review): start tokens take the dtype of hparams.sos_id;
                # GreedyEmbeddingHelper presumably expects int32 — confirm.
                start_tokens = tf.fill([self.batch_size], hparams.sos_id)
                end_tokens = hparams.eos_id
                greedy_helper = GreedyEmbeddingHelper(embedding, start_tokens, end_tokens)
                pre_decoder = BasicDecoder(cell, greedy_helper, cell.zero_state(self.batch_size, tf.float32),
                                           dense_layer)
                # Greedy decoding is capped at 50 steps.
                pre_decoder_outputs, _, _1 = dynamic_decode(pre_decoder, impute_finished=True, maximum_iterations=50)
                self.predict_id = tf.cast(pre_decoder_outputs.sample_id, tf.int64)
                self.predict_string = self.index_to_string_table.lookup(self.predict_id)
Ejemplo n.º 4
0
    def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
            self, use_sequence_length):
        """Check that dynamic_decode with a TrainingHelper reproduces dynamic_rnn.

        The same LSTMCell runs over identical random inputs both ways. Its
        variables are shared through the "root" variable scope. The decoder
        outputs must equal the first max(sequence_length) steps of the
        dynamic_rnn outputs. The final states are compared only when
        ``use_sequence_length`` is True.

        Args:
          use_sequence_length: when True the decoder imputes finished steps
            and dynamic_rnn receives the per-example sequence lengths.
        """
        lengths = [3, 4, 3, 1, 0]
        batch_size, max_time = 5, 8
        input_depth, cell_depth = 7, 10
        longest = max(lengths)
        rnn_lengths = lengths if use_sequence_length else None

        with self.test_session(use_gpu=True) as session:
            inputs = np.random.randn(
                batch_size, max_time, input_depth).astype(np.float32)

            lstm = core_rnn_cell.LSTMCell(cell_depth)
            initial_state = lstm.zero_state(dtype=dtypes.float32,
                                            batch_size=batch_size)
            training_decoder = basic_decoder.BasicDecoder(
                cell=lstm,
                helper=helper_py.TrainingHelper(inputs, lengths),
                initial_state=initial_state)

            # dynamic_rnn below reopens this scope with reuse=True, so both
            # paths read the same variables.
            with vs.variable_scope("root") as root_scope:
                decoded, decoded_state, _ = decoder.dynamic_decode(
                    training_decoder,
                    # Imputing finished steps is what makes outputs and final
                    # state agree with dynamic_rnn given a sequence_length.
                    impute_finished=use_sequence_length,
                    scope=root_scope)

            with vs.variable_scope(root_scope, reuse=True) as shared_scope:
                rnn_outputs, rnn_state = rnn.dynamic_rnn(
                    lstm,
                    inputs,
                    sequence_length=rnn_lengths,
                    initial_state=initial_state,
                    scope=shared_scope)

            session.run(variables.global_variables_initializer())
            fetched = session.run({
                "final_decoder_outputs": decoded,
                "final_decoder_state": decoded_state,
                "final_rnn_outputs": rnn_outputs,
                "final_rnn_state": rnn_state
            })

            # The decoder only runs for `longest` steps, so it is compared
            # against that prefix of the dynamic_rnn outputs.
            decoder_rnn_output = fetched["final_decoder_outputs"].rnn_output
            rnn_prefix = fetched["final_rnn_outputs"][:, :longest, :]
            self.assertAllClose(decoder_rnn_output, rnn_prefix)
            if use_sequence_length:
                self.assertAllClose(fetched["final_decoder_state"],
                                    fetched["final_rnn_state"])
Ejemplo n.º 5
0
    def beam_decoder_model(self,
                           context,
                           cell,
                           embedding,
                           output_layer,
                           beam_width,
                           mode=0):
        """Build (once) a beam-search decoding graph and return one of its results.

        The first call builds the graph and stores ``self.beam_output``
        (predicted ids) and ``self.beam_output_scores`` on the instance. Later
        calls return the stored result for the requested mode without
        rebuilding.

        Args:
          context: initial decoder state; each batch entry is repeated
            ``beam_width`` times with ``tile_batch``.
          cell: RNN cell driven by the BeamSearchDecoder.
          embedding: embedding passed to the BeamSearchDecoder.
          output_layer: layer the BeamSearchDecoder applies to cell outputs.
          beam_width: number of beams.
          mode: 0 returns ``self.beam_output``; any other value returns
            ``self.outputs_manual``.

        Returns:
          ``self.beam_output`` when ``mode == 0``, otherwise
          ``self.outputs_manual``.

        Raises:
          AttributeError: for a non-zero ``mode`` when ``self.outputs_manual``
            has not been set. This method does not set it itself (the manual
            stepping that would is not performed).
        """
        if hasattr(self, 'beam_output') and mode == 0:
            return self.beam_output

        # Every non-zero mode takes the final "manual outputs" branch below,
        # so reuse the built graph for all of them (not only mode == 1) and
        # avoid rebuilding the decoder on every call.
        if hasattr(self, 'beam_output_scores') and mode != 0:
            return self._manual_beam_outputs()

        # NOTE(review): end_token=-1 is presumably never produced (token ids
        # are non-negative), so decoding runs until maximum_iterations.
        end_token = -1
        # Assumes y_past is (batch, time, features) — TODO confirm; the start
        # token is feature 0 of the last time step.
        start_tokens = tf.cast(self.y_past[:, -1], tf.int32)
        start_tokens = start_tokens[:, 0]

        # Repeat each batch entry's state beam_width times for the beams.
        cell_state = tf.contrib.seq2seq.tile_batch(context,
                                                   multiplier=beam_width)

        bsd = beam_search_decoder.BeamSearchDecoder(cell=cell,
                                                    embedding=embedding,
                                                    start_tokens=start_tokens,
                                                    end_token=end_token,
                                                    initial_state=cell_state,
                                                    beam_width=beam_width,
                                                    output_layer=output_layer,
                                                    length_penalty_weight=0.0)

        # Manual stepping via bsd.initialize() / bsd.step(), which would
        # collect per-step states into self.outputs_manual, is not performed
        # here; non-zero modes rely on outputs_manual being set elsewhere.

        final_outputs, final_state, final_sequence_lengths = (
            decoder.dynamic_decode(bsd,
                                   output_time_major=False,
                                   maximum_iterations=self.prediction_length))

        beam_search_decoder_output = final_outputs.beam_search_decoder_output

        self.beam_output = beam_search_decoder_output.predicted_ids[:, :, :]
        beam_output_scores = beam_search_decoder_output.scores

        print("Added beam_scores")
        self.beam_output_scores = tf.identity(beam_output_scores,
                                              name="beam_scores")

        if mode == 0:
            return self.beam_output
        return self._manual_beam_outputs()

    def _manual_beam_outputs(self):
        """Return ``self.outputs_manual``, raising a descriptive error if it is unset."""
        if not hasattr(self, 'outputs_manual'):
            raise AttributeError(
                "beam_decoder_model: outputs_manual was never set (the manual "
                "BeamSearchDecoder stepping that would fill it is disabled); "
                "call with mode=0 or read self.beam_output_scores instead.")
        return self.outputs_manual
Ejemplo n.º 6
0
  def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
      self, use_sequence_length):
    """Checks dynamic_decode + TrainingHelper against dynamic_rnn.

    The same LSTMCell runs over identical random inputs both ways. Its
    variables are shared through the "root" variable scope. The decoder
    outputs must equal the first max(sequence_length) steps of the
    dynamic_rnn outputs. The final states are compared only when
    `use_sequence_length` is True.

    Args:
      use_sequence_length: when True the decoder imputes finished steps and
        dynamic_rnn is given the per-example sequence lengths.
    """
    lengths = [3, 4, 3, 1, 0]
    batch_size, max_time = 5, 8
    input_depth, cell_depth = 7, 10
    longest = max(lengths)
    rnn_lengths = lengths if use_sequence_length else None

    with self.session(use_gpu=True) as session:
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)

      lstm = rnn_cell.LSTMCell(cell_depth)
      initial_state = lstm.zero_state(dtype=dtypes.float32,
                                      batch_size=batch_size)
      training_decoder = basic_decoder.BasicDecoder(
          cell=lstm,
          helper=helper_py.TrainingHelper(inputs, lengths),
          initial_state=initial_state)

      # dynamic_rnn below reopens this scope with reuse=True, so both paths
      # read the same variables.
      with vs.variable_scope("root") as root_scope:
        decoded, decoded_state, _ = decoder.dynamic_decode(
            training_decoder,
            # Imputing finished steps is what makes outputs and final state
            # agree with dynamic_rnn given a sequence_length.
            impute_finished=use_sequence_length,
            scope=root_scope)

      with vs.variable_scope(root_scope, reuse=True) as shared_scope:
        rnn_outputs, rnn_state = rnn.dynamic_rnn(
            lstm,
            inputs,
            sequence_length=rnn_lengths,
            initial_state=initial_state,
            scope=shared_scope)

      session.run(variables.global_variables_initializer())
      fetched = session.run({
          "final_decoder_outputs": decoded,
          "final_decoder_state": decoded_state,
          "final_rnn_outputs": rnn_outputs,
          "final_rnn_state": rnn_state
      })

      # The decoder only runs for `longest` steps, so it is compared against
      # that prefix of the dynamic_rnn outputs.
      decoder_rnn_output = fetched["final_decoder_outputs"].rnn_output
      rnn_prefix = fetched["final_rnn_outputs"][:, :longest, :]
      self.assertAllClose(decoder_rnn_output, rnn_prefix)
      if use_sequence_length:
        self.assertAllClose(fetched["final_decoder_state"],
                            fetched["final_rnn_state"])
Ejemplo n.º 7
0
    def inference(self, inputs, train=True):
        """Build the graph: embedding -> CBHG encoder -> attention decoder -> post-processing CBHG.

        Args:
          inputs: dict-like. Reads ``inputs['text']`` (ids looked up in the
            embedding table) and, when ``config.num_speakers > 1``,
            ``inputs['speaker']``. The whole dict is also passed to
            ``self.create_decoder``.
          train: forwarded to ``self.pre_net`` and ``self.create_decoder``.

        Returns:
          ``(seq2seq_output, output)``: the decoder output sequence, and the
          post-processed prediction reshaped to
          ``config.fft_size * config.r`` features per step.

        Side effects:
          Sets ``self.alignments`` and adds histogram summaries.
        """
        config = self.config

        # extract character representations from embedding
        with tf.variable_scope('embedding', initializer=tf.contrib.layers.xavier_initializer()):
            embedding = tf.get_variable('embedding',
                    shape=(config.vocab_size, config.embed_dim), dtype=tf.float32)
            embedded_inputs = tf.nn.embedding_lookup(embedding, inputs['text'])

        # extract speaker embedding if multi-speaker (None otherwise; it is
        # passed to the encoder CBHG and to create_decoder either way)
        with tf.variable_scope('speaker'):
            if config.num_speakers > 1:
                speaker_embed = tf.get_variable('speaker_embed',
                        shape=(config.num_speakers, config.speaker_embed_dim), dtype=tf.float32)
                speaker_embed = \
                        tf.nn.embedding_lookup(speaker_embed, inputs['speaker'])
            else:
                speaker_embed = None

        # process text input with CBHG module
        with tf.variable_scope('encoder'):
            pre_out = self.pre_net(embedded_inputs, dropout=config.char_dropout_prob, train=train)
            tf.summary.histogram('pre_net_out', pre_out)

            encoded = ops.CBHG(pre_out, speaker_embed, K=16, c=[128,128,128], gru_units=128)

        # pass through attention based decoder
        with tf.variable_scope('decoder'):
            dec = self.create_decoder(encoded, inputs, speaker_embed, train)
            # The decoder outputs are unpacked as a pair and only the first
            # field is kept. NOTE(review): assumes the decoder built by
            # create_decoder emits 2-field outputs.
            (seq2seq_output, _), attention_state, _ = \
                    decoder.dynamic_decode(dec, maximum_iterations=config.max_decode_iter)
            # Stack the recorded alignment history and swap its first two
            # axes (presumably time-major -> batch-major).
            self.alignments = tf.transpose(attention_state.alignment_history.stack(), [1,0,2])
            tf.summary.histogram('seq2seq_output', seq2seq_output)

        # use second CBHG module to process mel features into linear spectrogram
        with tf.variable_scope('post-process'):
            # reshape to account for r value: (batch, -1, mel_features),
            # presumably unpacking r frames stored per decoder step
            post_input = tf.reshape(seq2seq_output, 
                    (tf.shape(seq2seq_output)[0], -1, config.mel_features))

            output = ops.CBHG(post_input, K=8, c=[128,256,80], gru_units=128)
            output = tf.layers.dense(output, units=config.fft_size)

            # reshape back to r frame representation
            output = tf.reshape(output, (tf.shape(output)[0], -1, config.fft_size*config.r))
            tf.summary.histogram('output', output)

        return seq2seq_output, output
Ejemplo n.º 8
0
    def conv_decoder_infer(self, encoder_output):
        """Decode ``encoder_output`` with a BeamDecoder through ``dynamic_decode``.

        Returns:
          The ``(outputs, final_state, final_sequence_lengths)`` triple from
          ``dynamic_decode``. It is run time-major, without imputing finished
          steps, for at most ``self.max_sequence_length`` steps under the
          'decoder' scope.
        """
        beam = BeamDecoder(self.params, self.mode, encoder_output,
                           self.num_charset)
        # Create the decoder's parameters up front: TensorFlow does not
        # support initializing a variable from a tensor inside a loop or a
        # conditional.
        beam.init_params_in_loop()
        # Switch the current variable scope to reuse mode so the decoding
        # loop picks up the parameters created above.
        tf.get_variable_scope().reuse_variables()

        decode_kwargs = dict(
            decoder=beam,
            output_time_major=True,
            impute_finished=False,
            maximum_iterations=self.max_sequence_length,
            scope='decoder',
        )
        outputs, state, lengths = dynamic_decode(**decode_kwargs)
        return outputs, state, lengths
  def _init_decoder(self):
    """Build the decoder graph and its prediction or probability tensors.

    The helper depends on ``self.is_inference``:

    * When set, a GreedyEmbeddingHelper over ``self.embedding_matrix``, fed
      from new ``start_tokens`` / ``end_token`` placeholders.
    * Otherwise, a time-major TrainingHelper over
      ``self.decoder_train_inputs_embedded``.

    Either helper drives a BasicDecoder that starts from
    ``self.encoder_state`` and projects outputs to ``self.vocab_size`` with
    ``self.fc_layer``.

    Always sets ``self.logits``. In training mode it also sets the argmax
    prediction tensors and calls ``self._init_optimizer()``. In inference mode
    it sets ``self.prob``.
    """
    # NOTE(review): `scope` is bound here but never used below.
    with tf.variable_scope('Decoder') as scope:

      self.fc_layer = Dense(self.vocab_size)
      
      if self.is_inference:
        # Start/end tokens are fed at run time.
        self.start_tokens = tf.placeholder(tf.int32,shape=[None],name='start_tokens')
        self.end_token = tf.placeholder(tf.int32,name='end_token')
        
      
        self.helper = seq2seq.GreedyEmbeddingHelper(
            embedding=self.embedding_matrix,
            start_tokens=self.start_tokens,
            end_token=self.end_token
        )
      else:
        self.helper = seq2seq.TrainingHelper(
            inputs=self.decoder_train_inputs_embedded, 
            sequence_length=self.decoder_train_length,
            time_major=True
        )
      
      self.decoder = seq2seq.BasicDecoder(
          cell=self.decoder_cell,
          helper=self.helper,
          initial_state=self.encoder_state,
          output_layer=self.fc_layer
      )
      
      
      # NOTE(review): no maximum_iterations is passed. In inference mode
      # decoding then ends only once every sequence emits self.end_token, so
      # a model that never emits it would not terminate; consider a length
      # cap.
      # NOTE(review): in the tf.contrib.seq2seq API the third value returned
      # by dynamic_decode is the final sequence lengths, not a context state
      # — confirm for the TF version in use.
      (self.decoder_outputs_train,
       self.decoder_state_train,
       self.decoder_context_state_train
       ) = (
           decoder.dynamic_decode(
               self.decoder, 
               output_time_major=True)
      )
      self.logits = self.decoder_outputs_train.rnn_output
      if not self.is_inference:
        # In training mode both prediction tensors are the argmax of the same
        # training logits; only their op names differ.
        self.decoder_prediction_inference = tf.argmax(self.logits, axis=-1, name='decoder_prediction_inference')
      
        self.decoder_prediction_train = tf.argmax(self.decoder_outputs_train.rnn_output, axis=-1, name='decoder_prediction_train')
        
        self._init_optimizer()
      else:
        self.prob = tf.nn.softmax(self.logits)
Ejemplo n.º 10
0
    def inference(self, inputs, train=True):
        """Build the graph: embedding -> CBHG encoder -> attention decoder -> post-processing CBHG.

        Args:
          inputs: dict-like. Reads ``inputs['text']`` (ids looked up in the
            embedding table). The whole dict is also passed to
            ``self.create_decoder``.
          train: forwarded to ``self.pre_net`` and ``self.create_decoder``.

        Returns:
          ``(seq2seq_output, output)``: the decoder output sequence, and the
          post-processed prediction reshaped to
          ``config.fft_size * config.r`` features per step.

        Side effects:
          Sets ``self.alignments`` and adds histogram summaries.
        """
        config = self.config

        # extract character representations from embedding
        with tf.variable_scope(
                'embedding',
                initializer=tf.contrib.layers.xavier_initializer()):
            embedding = tf.get_variable('embedding',
                                        shape=(config.vocab_size,
                                               config.embed_dim),
                                        dtype=tf.float32)
            embedded_inputs = tf.nn.embedding_lookup(embedding, inputs['text'])

        # process text input with CBHG module
        with tf.variable_scope('encoder'):
            pre_out = self.pre_net(embedded_inputs, train=train)
            tf.summary.histogram('pre_net_out', pre_out)

            encoded = ops.CBHG(pre_out, K=16, c=[128, 128, 128], gru_units=128)

        # pass through attention based decoder
        with tf.variable_scope('decoder'):
            dec = self.create_decoder(encoded, inputs, train)
            # The decoder outputs are unpacked as a pair and only the first
            # field is kept. NOTE(review): assumes the decoder built by
            # create_decoder emits 2-field outputs.
            (seq2seq_output, _), attention_state, _ = \
                    decoder.dynamic_decode(dec, maximum_iterations=config.max_decode_iter)
            # Stack the recorded alignment history and swap its first two
            # axes (presumably time-major -> batch-major).
            self.alignments = tf.transpose(
                attention_state.alignment_history.stack(), [1, 0, 2])
            tf.summary.histogram('seq2seq_output', seq2seq_output)

        # use second CBHG module to process mel features into linear spectrogram
        with tf.variable_scope('post-process'):
            # reshape to account for r value: (batch, -1, mel_features),
            # presumably unpacking r frames stored per decoder step
            post_input = tf.reshape(
                seq2seq_output,
                (tf.shape(seq2seq_output)[0], -1, config.mel_features))

            output = ops.CBHG(post_input, K=8, c=[128, 256, 80], gru_units=128)
            output = tf.layers.dense(output, units=config.fft_size)

            # reshape back to r frame representation
            output = tf.reshape(
                output, (tf.shape(output)[0], -1, config.fft_size * config.r))
            tf.summary.histogram('output', output)

        return seq2seq_output, output
Ejemplo n.º 11
0
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))

# inference
embedding = tf.get_variable("embedding",
                            shape=(10, 16),
                            initializer=tf.random_uniform_initializer())
infer_helper = helper_py.GreedyEmbeddingHelper(
    embedding=embedding,  # may be a callable or an embedding matrix
    start_tokens=tf.zeros([batch_size], dtype=tf.int32),
    end_token=9)
infer_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=infer_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))

# --- training decode ---
# maximum_iterations=None means "no step cap". Do not pass False: bool is an
# int subclass, so False is treated as a cap of 0 and decoding finishes
# before producing a single step.
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(
    train_decoder, maximum_iterations=None)

print(final_outputs.rnn_output)
print(final_outputs.sample_id)
print(final_state.cell_state)
print(final_sequence_lengths)

print("--------infer-------------")
# NOTE: with no cap, greedy decoding stops only once every sequence has
# emitted end_token=9; pass an int as maximum_iterations to bound it.
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(
    infer_decoder, maximum_iterations=None)

print(final_outputs.rnn_output)
print(final_outputs.sample_id)
print(final_state.cell_state)
print(final_sequence_lengths)
Ejemplo n.º 12
0
  def _testWithMaybeMultiAttention(self,
                                   is_multi,
                                   create_attention_mechanisms,
                                   expected_final_output,
                                   expected_final_state,
                                   attention_mechanism_depths,
                                   alignment_history=False,
                                   expected_final_alignment_history=None,
                                   attention_layer_sizes=None,
                                   attention_layers=None,
                                   name=''):
    """Build an AttentionWrapper decoder and compare its results to expectations.

    The test proceeds in these steps:

    1. Wrap an LSTMCell in one or several attention mechanisms over random
       encoder outputs.
    2. Decode random inputs with a TrainingHelper.
    3. Check the static shapes of the outputs and state.
    4. Run the graph and compare summaries (via ``get_result_summary``) of
       the final outputs, the final state and, optionally, the alignment
       history against the expected values.

    The computed summaries are printed so the expectations can be
    regenerated by copy/paste.

    Args:
      is_multi: pass the whole list of mechanisms (and of layer sizes /
        layers) to AttentionWrapper instead of only the first element.
      create_attention_mechanisms: callables invoked with ``num_units``,
        ``memory`` and ``memory_sequence_length`` to build each mechanism.
      expected_final_output: expected summary of the final outputs.
      expected_final_state: expected summary of the final state (with
        ``alignment_history`` replaced by ``()`` when it is recorded).
      attention_mechanism_depths: ``num_units`` for each mechanism, zipped
        with ``create_attention_mechanisms``.
      alignment_history: whether the wrapper records alignment history.
      expected_final_alignment_history: expected summary of the stacked
        alignment history; only used when ``alignment_history`` is True.
      attention_layer_sizes: optional per-mechanism attention layer sizes.
      attention_layers: optional per-mechanism attention layers; used for the
        expected depth only when ``attention_layer_sizes`` is None.
      name: label printed before the copy/paste expectations.
    """
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    # Expected depth of the wrapper's attention output; checked against
    # rnn_output and final_state.attention below.
    if attention_layer_sizes is not None:
      # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
      attention_depth = sum([attention_layer_size or encoder_output_depth
                             for attention_layer_size in attention_layer_sizes])
    elif attention_layers is not None:
      # Compute sum of attention_layers output depth. Each layer's input
      # width is taken as cell_depth + encoder_output_depth (presumably cell
      # output concatenated with the attention context).
      attention_depth = sum(
          attention_layer.compute_output_shape(
              [batch_size, cell_depth + encoder_output_depth])[-1].value
          for attention_layer in attention_layers)
    else:
      # No attention layer: encoder_output_depth per mechanism.
      attention_depth = encoder_output_depth * len(create_attention_mechanisms)

    # Random inputs whose batch and time dimensions are left statically
    # unknown (None) in the placeholder shapes.
    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanisms = [
        creator(num_units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths)]

    with self.test_session(use_gpu=True) as sess:
      with vs.variable_scope(
          'root',
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        # In the single-mechanism case pass only the first entry of the
        # per-mechanism lists.
        attention_layer_size = attention_layer_sizes
        attention_layer = attention_layers
        if not is_multi:
          if attention_layer_size is not None:
            attention_layer_size = attention_layer_size[0]
          if attention_layer is not None:
            attention_layer = attention_layer[0]
        cell = rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history,
            attention_layer=attention_layer)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      # Static type and shape checks (graph only, before running).
      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        # Stack each alignment-history TensorArray; the stacked tensor has
        # static shape (None, batch_size, None).
        if is_multi:
          state_alignment_history = []
          for history_array in final_state.alignment_history:
            history = history_array.stack()
            self.assertEqual(
                (None, batch_size, None),
                tuple(history.get_shape().as_list()))
            state_alignment_history.append(history)
          state_alignment_history = tuple(state_alignment_history)
        else:
          state_alignment_history = final_state.alignment_history.stack()
          self.assertEqual(
              (None, batch_size, None),
              tuple(state_alignment_history.get_shape().as_list()))
        nest.assert_same_structure(
            cell.state_size,
            cell.zero_state(batch_size, dtypes.float32))
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'state_alignment_history': state_alignment_history,
      })

      # Print the computed summaries so the expected values can be
      # regenerated by copy/paste, then compare them with the expectations.
      final_output_info = nest.map_structure(get_result_summary,
                                             sess_results['final_outputs'])
      final_state_info = nest.map_structure(get_result_summary,
                                            sess_results['final_state'])
      print(name)
      print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
      print('expected_final_state = %s' % str(final_state_info))
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                         final_output_info)
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                         final_state_info)
      if alignment_history:  # also compare the stacked alignment history fetched above
        final_alignment_history_info = nest.map_structure(
            get_result_summary, sess_results['state_alignment_history'])
        print('expected_final_alignment_history = %s' %
              str(final_alignment_history_info))
        nest.map_structure(
            self.assertAllCloseOrEqual,
            # outputs are batch major but the stacked TensorArray is time major
            expected_final_alignment_history,
            final_alignment_history_info)
Ejemplo n.º 13
0
def impl(features, mode, hp):
    """Build the hierarchical-attention response generator for ``mode``.

    Each context utterance is encoded with a bidirectional GRU whose
    variables are shared across utterances (AUTO_REUSE scope).  A
    ``ContextAttentionMechanism`` attends over those encodings, and a GRU
    decoder wrapped in ``AttentionWrapper`` produces the response.

    Args:
      features: dict of input tensors. Keys read: 'contexts',
        'context_utterance_length', 'context_length', 'response_in',
        'response_out', 'response_mask' (shape notes below are the original
        author's inline comments).
      mode: ``modekeys.TRAIN``, ``modekeys.EVAL``, or any other value, which
        is treated as inference (greedy decoding).
      hp: hyper-parameter object; reads word_dim, vocab_size,
        word_rnn_num_units, batch_size, max_context_length,
        context_rnn_num_units, decoder_rnn_num_units, max_sentence_length.

    Returns:
      TRAIN: scalar loss -- masked token cross-entropy summed over time and
        averaged over the batch.
      EVAL: scalar perplexity -- exp of the masked cross-entropy averaged over
        the total mask weight, i.e. over real target tokens only.
      otherwise: ``(sample_id, final_sequence_lengths)`` from greedy decoding.
    """
    contexts = features['contexts']  # batch_size,max_con_length(with query),max_sen_length
    context_utterance_length = features['context_utterance_length']  # batch_size,max_con_length
    context_length = features['context_length']  # batch_size
    response_in = features['response_in']  # batch_size,max_res_length(with eos token)
    response_out = features['response_out']  # batch_size, max_res_length (with eos token append before)
    response_mask = features['response_mask']  # batch_size, max_res_length (with eos token append before)

    # TRAIN and EVAL both decode with teacher forcing on the gold response.
    teacher_forcing = mode == modekeys.TRAIN or mode == modekeys.EVAL

    with tf.variable_scope('embedding_layer'):
        embedding_w = get_embedding_matrix(hp.word_dim, mode, hp.vocab_size)
        contexts = tf.nn.embedding_lookup(embedding_w, contexts, 'context_embedding')
        if teacher_forcing:
            response_in = tf.nn.embedding_lookup(embedding_w, response_in, 'response_in_embedding')

    with tf.variable_scope('utterance_encoding_layer', reuse=tf.AUTO_REUSE):
        kernel_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=random_seed)
        bias_initializer = tf.zeros_initializer()
        fw_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.word_rnn_num_units, kernel_initializer=kernel_initializer,
                                         bias_initializer=bias_initializer)
        bw_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.word_rnn_num_units, kernel_initializer=kernel_initializer,
                                         bias_initializer=bias_initializer)

        # Move the context axis first so each utterance slot can be split off.
        context_t = tf.transpose(contexts, perm=[1, 0, 2, 3])  # max_con_length(with query),batch_size,max_sen_length,emb
        context_utterance_length_t = tf.transpose(context_utterance_length, perm=[1, 0])  # max_con_length, batch_size
        utterance_slices = tf.split(context_t, hp.max_context_length, axis=0)  # each: 1,batch_size,max_sen_length,emb
        length_slices = tf.split(context_utterance_length_t, hp.max_context_length, axis=0)  # each: 1,batch_size

        utterance_encodings = []
        for utterance, length in zip(utterance_slices, length_slices):
            utterance = tf.squeeze(utterance, axis=0)
            length = tf.squeeze(length, axis=0)
            utterance_hidden_states, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, utterance, sequence_length=length,
                initial_state_fw=fw_cell.zero_state(hp.batch_size, tf.float32),
                # The backward direction's initial state must come from bw_cell.
                initial_state_bw=bw_cell.zero_state(hp.batch_size, tf.float32))
            utterance_encoding = tf.concat(utterance_hidden_states, axis=2)
            utterance_encodings.append(tf.expand_dims(utterance_encoding, axis=0))

        utterance_encodings = tf.concat(utterance_encodings, axis=0)  # max_con_length,batch_size,max_sen,2*word_rnn_num_units

    with tf.variable_scope('hierarchical_attention_layer', reuse=tf.AUTO_REUSE):
        attention_mechanism = ContextAttentionMechanism(context_num_units=100, context=utterance_encodings,
                                                        context_utterance_length=context_utterance_length_t,
                                                        max_context_length=hp.max_context_length,
                                                        context_rnn_num_units=hp.context_rnn_num_units,
                                                        context_actual_length=context_length)

    with tf.variable_scope('decoder_layer', reuse=tf.AUTO_REUSE):
        kernel_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=random_seed)
        bias_initializer = tf.zeros_initializer()
        decoder_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.decoder_rnn_num_units, kernel_initializer=kernel_initializer,
                                              bias_initializer=bias_initializer)

        sequence_length = tf.constant(value=hp.max_sentence_length, dtype=tf.int32, shape=[hp.batch_size])
        if teacher_forcing:
            helper = tf.contrib.seq2seq.TrainingHelper(inputs=response_in, sequence_length=sequence_length)
        else:
            start_tokens = tf.constant(value=1, dtype=tf.int32, shape=[hp.batch_size], name='start_tokens')
            end_token = 1
            helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embedding_w, start_tokens=start_tokens,
                                                              end_token=end_token)

        attn_cell = AttentionWrapper(decoder_cell, attention_mechanism=attention_mechanism,
                                     attention_layer_size=None, output_attention=False)  # output_attention should be False
        output_layer = layers_core.Dense(units=hp.vocab_size, activation=None, use_bias=False)  # should use no activation and no bias
        decoder = BasicDecoder(cell=attn_cell, helper=helper,
                               initial_state=attn_cell.zero_state(hp.batch_size, tf.float32),
                               output_layer=output_layer)

        decode_kwargs = dict(decoder=decoder, impute_finished=True, parallel_iterations=32, swap_memory=True)

        if not teacher_forcing:
            # Inference: cap greedy decoding at twice the training length.
            max_iter = tf.constant(2 * hp.max_sentence_length, dtype=tf.int32, shape=[])
            final_outputs, _, final_sequence_lengths = dynamic_decode(maximum_iterations=max_iter, **decode_kwargs)
            return final_outputs.sample_id, final_sequence_lengths  # batch, T; T = max(final_sequence_lengths)

        final_outputs, _, _ = dynamic_decode(**decode_kwargs)
        logits = final_outputs.rnn_output
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=response_out, logits=logits)
        masked_cross_entropy = cross_entropy * response_mask

        if mode == modekeys.TRAIN:
            # Sum over time per example, then average over the batch.
            return tf.reduce_mean(tf.reduce_sum(masked_cross_entropy, axis=1))

        # EVAL: average over real (unmasked) tokens only; a plain reduce_mean
        # would also divide by padded positions and underestimate perplexity.
        # The epsilon guards against an all-zero mask.
        token_weight = tf.maximum(tf.reduce_sum(response_mask), 1e-12)
        return tf.exp(tf.reduce_sum(masked_cross_entropy) / token_weight)
Ejemplo n.º 14
0
    def _build_decoder(self, encoder_outputs, encoder_state):
        """Build the attention decoder graph and decode for ``self.mode``.

        In TRAIN mode the decoder is teacher-forced on ``iterator.target_in``
        for ``iterator.target_length`` steps. Otherwise it decodes greedily
        (``beam_width == 1``) or with beam search. Either way it runs for at
        most ``round(2.0 * max(source_length))`` steps.

        Args:
          encoder_outputs: encoder outputs, used as the attention memory.
          encoder_state: encoder final state. NOTE(review): currently unused;
            the decoder starts from its zero state (see comment below).

        Returns:
          ``(output_ids, outputs)``:
            * TRAIN: ``(final_outputs.rnn_output, final_outputs.sample_id)``.
              NOTE(review): this order is the reverse of the greedy branch;
              it is kept unchanged because callers may depend on it.
            * greedy inference: ``(sample_id, rnn_output)``.
            * beam search: ``(predicted_ids, beam_search_decoder_output.scores)``.
        """
        with tf.name_scope("seq_decoder"):
            is_training = self.mode == tf.contrib.learn.ModeKeys.TRAIN
            use_beam_search = (not is_training) and self.beam_width > 1

            batch_size = self.batch_size
            # The attention memory is always the encoder output, so its mask
            # must use the *source* lengths in every mode (target lengths
            # describe the decoder side, not the memory).
            memory_sequence_length = self.iterator.source_length
            if use_beam_search:
                batch_size = batch_size * self.beam_width
                encoder_outputs = beam_search_decoder.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                memory_sequence_length = beam_search_decoder.tile_batch(
                    memory_sequence_length, multiplier=self.beam_width)

            # Build a distinct cell per layer: passing the same cell object
            # for every layer would make all layers share one set of weights.
            decoder_cell = MultiRNNCell([
                single_rnn_cell(self.hparams.unit_type, self.num_units,
                                self.dropout)
                for _ in range(self.num_layers_decoder)])
            decoder_cell = InputProjectionWrapper(decoder_cell,
                                                  num_proj=self.num_units)
            attention_mechanism = create_attention_mechanism(
                self.hparams.attention_mechanism,
                self.num_units,
                memory=encoder_outputs,
                source_sequence_length=memory_sequence_length)
            decoder_cell = wrapper.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.num_units,
                output_attention=True,
                alignment_history=False)

            # Intended (per the original author's comment): set the
            # AttentionWrapperState's cell_state to the encoder state.
            # NOTE(review): not implemented -- the decoder starts from zeros.
            # Wiring it in would need initial_state.clone(cell_state=...)
            # with a structure matching the MultiRNNCell state; confirm.
            initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            embeddings_decoder = tf.get_variable(
                "embedding_decoder",
                [self.num_decoder_symbols, self.num_units],
                initializer=self.initializer,
                dtype=tf.float32)
            output_layer = Dense(units=self.num_decoder_symbols,
                                 use_bias=True,
                                 name="output_layer")

            if is_training:
                decoder_inputs = tf.nn.embedding_lookup(
                    embeddings_decoder, self.iterator.target_in)
                decoder_helper = helper.TrainingHelper(
                    decoder_inputs,
                    sequence_length=self.iterator.target_length)

                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
                final_outputs, final_state, _ = decoder.dynamic_decode(dec)
                output_ids = final_outputs.rnn_output
                outputs = final_outputs.sample_id
            else:

                def embedding_fn(inputs):
                    return tf.nn.embedding_lookup(embeddings_decoder, inputs)

                decoding_length_factor = 2.0
                max_encoder_length = tf.reduce_max(self.iterator.source_length)
                maximum_iterations = tf.to_int32(
                    tf.round(
                        tf.to_float(max_encoder_length) *
                        decoding_length_factor))

                tgt_sos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                    tf.int32)
                tgt_eos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                    tf.int32)
                # Start tokens are per *source* example; BeamSearchDecoder
                # expands them across beams itself.
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                if self.beam_width == 1:
                    decoder_helper = helper.GreedyEmbeddingHelper(
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    dec = basic_decoder.BasicDecoder(decoder_cell,
                                                     decoder_helper,
                                                     initial_state,
                                                     output_layer=output_layer)
                else:
                    dec = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        output_layer=output_layer,
                        beam_width=self.beam_width)
                final_outputs, final_state, _ = decoder.dynamic_decode(
                    dec,
                    # swap_memory=True,
                    maximum_iterations=maximum_iterations)
                # Mode is never TRAIN in this branch, so only beam_width
                # decides which output structure dynamic_decode returned.
                if self.beam_width == 1:
                    output_ids = final_outputs.sample_id
                    outputs = final_outputs.rnn_output
                else:
                    output_ids = final_outputs.predicted_ids
                    outputs = final_outputs.beam_search_decoder_output.scores

            return output_ids, outputs
Ejemplo n.º 15
0
  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    """Smoke-test BeamSearchDecoder driven by dynamic_decode.

    An LSTM cell is used directly or, when ``has_attention`` is set, wrapped
    in Bahdanau attention over beam-tiled memory, with a non-zero coverage
    penalty. The decoder runs for at most max(decoder lengths) steps. Then
    the static and evaluated shapes of the beam scores and predicted ids are
    checked.

    Args:
      time_major: whether dynamic_decode emits time-major outputs.
      has_attention: wrap the cell in an AttentionWrapper.
      with_alignment_history: forwarded to AttentionWrapper.
    """
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    def _maybe_time_major(shape):
      # dynamic_decode swaps the batch and time axes in time-major mode.
      return (shape[1], shape[0]) + shape[2:] if time_major else shape

    with self.cached_session() as sess:
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.2 if has_attention else 0.0

      if has_attention:
        inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        # Memory, lengths and the wrapped cell's state are all tiled so each
        # beam sees its own copy of the source.
        tiled_memory = beam_search_decoder.tile_batch(
            inputs, multiplier=beam_width)
        tiled_memory_lengths = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_memory,
            memory_sequence_length=tiled_memory_lengths)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)

      decoder_start_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        decoder_start_state = decoder_start_state.clone(
            cell_state=initial_state)

      beam_decoder = beam_search_decoder.BeamSearchDecoder(
          cell=cell,
          embedding=embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=decoder_start_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight)

      final_outputs, final_state, final_sequence_lengths = (
          decoder.dynamic_decode(
              beam_decoder, output_time_major=time_major,
              maximum_iterations=max_out))

      self.assertTrue(
          isinstance(final_outputs,
                     beam_search_decoder.FinalBeamSearchDecoderOutput))
      self.assertTrue(
          isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))

      # Static shapes: the time dimension is unknown until run time.
      static_shape = _maybe_time_major((batch_size, None, beam_width))
      for produced in (final_outputs.beam_search_decoder_output.scores,
                       final_outputs.predicted_ids):
        self.assertEqual(static_shape,
                         tuple(produced.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      evaluated = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'final_sequence_lengths': final_sequence_lengths
      })

      # A smoke test: evaluated outputs span the longest decoded sequence.
      longest = np.max(evaluated['final_sequence_lengths'])
      runtime_shape = _maybe_time_major((batch_size, longest, beam_width))
      step_outputs = evaluated['final_outputs'].beam_search_decoder_output
      self.assertEqual(runtime_shape, step_outputs.scores.shape)
      self.assertEqual(runtime_shape, step_outputs.predicted_ids.shape)
Ejemplo n.º 16
0
  def _testDynamicDecodeRNN(self, time_major, has_attention):
    """Smoke-test BeamSearchDecoder + dynamic_decode output shapes.

    The decoder runs for exactly max(decoder lengths) steps. Static and
    evaluated shapes of the beam scores and predicted ids are then checked.

    Args:
      time_major: whether dynamic_decode emits time-major outputs.
      has_attention: wrap the LSTM cell in a Bahdanau AttentionWrapper.
    """
    encoder_sequence_length = [3, 2, 3, 1, 0]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    def _maybe_time_major(shape):
      # dynamic_decode swaps the batch and time axes in time-major mode.
      return (shape[1], shape[0]) + shape[2:] if time_major else shape

    with self.test_session() as sess:
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = core_rnn_cell.LSTMCell(cell_depth)
      if has_attention:
        # NOTE(review): this memory has `batch_size` rows, but the decoder
        # state below is built for batch_size * beam_width and the memory is
        # not tiled. Confirm the AttentionWrapper/BeamSearchDecoder version
        # in use supports this combination.
        memory = np.random.randn(batch_size, decoder_max_time,
                                 input_depth).astype(np.float32)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=memory,
            memory_sequence_length=encoder_sequence_length)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_size=attention_depth,
            alignment_history=False)
      decoder_start_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size * beam_width)
      beam_decoder = beam_search_decoder.BeamSearchDecoder(
          cell=cell,
          embedding=embedding,
          start_tokens=[start_token] * batch_size,
          end_token=end_token,
          initial_state=decoder_start_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0)

      final_outputs, final_state = decoder.dynamic_decode(
          beam_decoder, output_time_major=time_major,
          maximum_iterations=max_out)

      self.assertTrue(
          isinstance(final_outputs,
                     beam_search_decoder.FinalBeamSearchDecoderOutput))
      self.assertTrue(
          isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))

      # Static shapes: the time dimension is unknown until run time.
      static_shape = _maybe_time_major((batch_size, None, beam_width))
      for produced in (final_outputs.beam_search_decoder_output.scores,
                       final_outputs.predicted_ids):
        self.assertEqual(static_shape,
                         tuple(produced.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      evaluated = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state
      })

      # Mostly a smoke test: outputs span exactly max_out steps.
      runtime_shape = _maybe_time_major((batch_size, max_out, beam_width))
      step_outputs = evaluated['final_outputs'].beam_search_decoder_output
      self.assertEqual(runtime_shape, step_outputs.scores.shape)
      self.assertEqual(runtime_shape, step_outputs.predicted_ids.shape)
Ejemplo n.º 17
0
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_output,
                           expected_final_state,
                           attention_mechanism_depth=3,
                           alignment_history=False,
                           expected_final_alignment_history=None,
                           attention_layer_size=6,
                           name=''):
        """Decode with an attention-wrapped LSTM and compare result summaries.

        An LSTMCell is wrapped in AttentionWrapper over random encoder
        outputs and decoded with BasicDecoder + TrainingHelper. The static
        shapes are checked, and the ``get_result_summary`` structures of the
        evaluated outputs/state are compared against the expected values.
        Those summaries are also printed so new expectations can be pasted in.

        Args:
          create_attention_mechanism: callable taking ``num_units``,
            ``memory`` and ``memory_sequence_length`` keyword arguments and
            returning an attention mechanism.
          expected_final_output: expected summary structure of the decoder's
            final outputs.
          expected_final_state: expected summary structure of the final
            AttentionWrapperState (with ``alignment_history`` emptied).
          attention_mechanism_depth: ``num_units`` for the mechanism.
          alignment_history: whether the wrapper records alignment history;
            when True the stacked history is also checked.
          expected_final_alignment_history: expected summary of the stacked
            alignment history; only used when ``alignment_history`` is True.
          attention_layer_size: forwarded to AttentionWrapper; when None the
            expected attention depth falls back to the encoder output depth.
          name: label printed before the copy/paste-able expectations.
        """
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9

        # Expected last dimension of rnn_output / state.attention.
        if attention_layer_size is not None:
            attention_depth = attention_layer_size
        else:
            attention_depth = encoder_output_depth

        # Random inputs whose batch and time dims are left dynamic in the
        # static shape (only the depth is fixed).
        decoder_inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time,
                            input_depth).astype(np.float32),
            shape=(None, None, input_depth))
        encoder_outputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, encoder_max_time,
                            encoder_output_depth).astype(np.float32),
            shape=(None, None, encoder_output_depth))

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session(use_gpu=True) as sess:
            # Fixed-seed initializer, presumably so the recorded expected
            # summaries stay stable across runs.
            with vs.variable_scope(
                    'root',
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_layer_size=attention_layer_size,
                    alignment_history=alignment_history)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                final_outputs, final_state, _ = decoder.dynamic_decode(
                    my_decoder)

            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.AttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

            # Static shape checks; the decoded time dimension is unknown.
            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            if alignment_history:
                # Stacked history is time-major: (time, batch, memory_time).
                state_alignment_history = final_state.alignment_history.stack()
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
                self.assertEqual(
                    (None, batch_size, None),
                    tuple(state_alignment_history.get_shape().as_list()))
            else:
                state_alignment_history = ()

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                'final_outputs':
                final_outputs,
                'final_state':
                final_state,
                'state_alignment_history':
                state_alignment_history,
            })

            # Reduce evaluated tensors to summaries and compare them with the
            # expectations; the prints let new expectations be pasted in.
            final_output_info = nest.map_structure(
                get_result_summary, sess_results['final_outputs'])
            final_state_info = nest.map_structure(get_result_summary,
                                                  sess_results['final_state'])
            print(name)
            print('Copy/paste:\nexpected_final_output = %s' %
                  str(final_output_info))
            print('expected_final_state = %s' % str(final_state_info))
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_output, final_output_info)
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_state, final_state_info)
            if alignment_history:  # by default, the wrapper emits attention as output
                final_alignment_history_info = nest.map_structure(
                    get_result_summary,
                    sess_results['state_alignment_history'])
                print('expected_final_alignment_history = %s' %
                      str(final_alignment_history_info))
                nest.map_structure(
                    self.assertAllCloseOrEqual,
                    # outputs are batch major but the stacked TensorArray is time major
                    expected_final_alignment_history,
                    final_alignment_history_info)
Ejemplo n.º 18
0
  def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):
    """Check BasicDecoder + dynamic_decode shapes and per-example lengths.

    Args:
      time_major: feed inputs and request outputs in time-major layout.
      maximum_iterations: optional cap forwarded to dynamic_decode. The
        expected number of decoded steps and each example's expected length
        are clipped to it.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    max_out = max(sequence_length)

    with self.session(use_gpu=True) as sess:
      if time_major:
        inputs = np.random.randn(max_time, batch_size,
                                 input_depth).astype(np.float32)
      else:
        inputs = np.random.randn(batch_size, max_time,
                                 input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=time_major)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      final_outputs, final_state, final_sequence_length = (
          decoder.dynamic_decode(my_decoder, output_time_major=time_major,
                                 maximum_iterations=maximum_iterations))

      def _t(shape):
        # dynamic_decode swaps the batch and time axes in time-major mode.
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
      self.assertIsInstance(final_state, rnn_cell.LSTMStateTuple)

      self.assertEqual(
          (batch_size,),
          tuple(final_sequence_length.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None, cell_depth)),
          tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None)),
          tuple(final_outputs.sample_id.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "final_outputs": final_outputs,
          "final_state": final_state,
          "final_sequence_length": final_sequence_length,
      })

      # Mostly a smoke test
      time_steps = max_out
      expected_length = sequence_length
      if maximum_iterations is not None:
        time_steps = min(max_out, maximum_iterations)
        expected_length = [min(x, maximum_iterations) for x in expected_length]
      self.assertEqual(
          _t((batch_size, time_steps, cell_depth)),
          sess_results["final_outputs"].rnn_output.shape)
      self.assertEqual(
          _t((batch_size, time_steps)),
          sess_results["final_outputs"].sample_id.shape)
      # Lengths belong to specific batch entries, so compare element-wise;
      # an unordered (multiset) comparison would miss permuted lengths.
      self.assertAllEqual(expected_length,
                          sess_results["final_sequence_length"])
Ejemplo n.º 19
0
  def _testWithMaybeMultiAttention(self,
                                   is_multi,
                                   create_attention_mechanisms,
                                   expected_final_output,
                                   expected_final_state,
                                   attention_mechanism_depths,
                                   alignment_history=False,
                                   expected_final_alignment_history=None,
                                   attention_layer_sizes=None,
                                   name=''):
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    if attention_layer_sizes is None:
      attention_depth = encoder_output_depth * len(create_attention_mechanisms)
    else:
      # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
      attention_depth = sum([attention_layer_size or encoder_output_depth
                             for attention_layer_size in attention_layer_sizes])

    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanisms = [
        creator(num_units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths)]

    with self.test_session(use_gpu=True) as sess:
      with vs.variable_scope(
          'root',
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        cell = rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=(attention_layer_sizes if is_multi
                                  else attention_layer_sizes[0]),
            alignment_history=alignment_history)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        if is_multi:
          state_alignment_history = []
          for history_array in final_state.alignment_history:
            history = history_array.stack()
            self.assertEqual(
                (None, batch_size, None),
                tuple(history.get_shape().as_list()))
            state_alignment_history.append(history)
          state_alignment_history = tuple(state_alignment_history)
        else:
          state_alignment_history = final_state.alignment_history.stack()
          self.assertEqual(
              (None, batch_size, None),
              tuple(state_alignment_history.get_shape().as_list()))
        nest.assert_same_structure(
            cell.state_size,
            cell.zero_state(batch_size, dtypes.float32))
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'state_alignment_history': state_alignment_history,
      })

      final_output_info = nest.map_structure(get_result_summary,
                                             sess_results['final_outputs'])
      final_state_info = nest.map_structure(get_result_summary,
                                            sess_results['final_state'])
      print(name)
      print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
      print('expected_final_state = %s' % str(final_state_info))
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                         final_output_info)
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                         final_state_info)
      if alignment_history:  # by default, the wrapper emits attention as output
        final_alignment_history_info = nest.map_structure(
            get_result_summary, sess_results['state_alignment_history'])
        print('expected_final_alignment_history = %s' %
              str(final_alignment_history_info))
        nest.map_structure(
            self.assertAllCloseOrEqual,
            # outputs are batch major but the stacked TensorArray is time major
            expected_final_alignment_history,
            final_alignment_history_info)
Ejemplo n.º 20
0
    def build(self):
        """Build the topic-aware attention seq2seq graph.

        Creates the string<->id lookup tables, the string placeholders fed by
        callers, the embedding and output-projection variables, the encoder
        RNN and an ``AttnTopicDecoder`` that attends over both the encoder
        outputs and the projected topic embeddings.

        When ``self.for_deploy`` is falsy the training branch is built
        (``self.outputs`` for debugging, ``self.loss``, ``self.summary``,
        backprop via ``self.build_backprop`` and ``self.saver``). Otherwise a
        greedy-decoding inference branch is built together with
        ``self.saver`` and ``self.model_exporter`` for serving.

        In both cases ``self.dec_states`` is set to the decoder's final state.
        """
        conf = self.conf
        dtype = self.dtype

        # Input maps (checkpoint=True keeps the table contents in checkpoints).
        self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
                                                value_dtype=tf.int64,
                                                default_value=UNK_ID,
                                                shared_name="in_table",
                                                name="in_table",
                                                checkpoint=True)

        # NOTE(review): unknown topics map to id 2 -- confirm what id 2 means.
        self.topic_in_table = lookup.MutableHashTable(
            key_dtype=tf.string,
            value_dtype=tf.int64,
            default_value=2,
            shared_name="topic_in_table",
            name="topic_in_table",
            checkpoint=True)

        self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
                                                 value_dtype=tf.string,
                                                 default_value="_UNK",
                                                 shared_name="out_table",
                                                 name="out_table",
                                                 checkpoint=True)

        graphlg.info("Creating placeholders...")
        self.enc_str_inps = tf.placeholder(tf.string,
                                           shape=(None, conf.input_max_len),
                                           name="enc_inps")
        self.enc_lens = tf.placeholder(tf.int32, shape=[None], name="enc_lens")

        self.enc_str_topics = tf.placeholder(tf.string,
                                             shape=(None, None),
                                             name="enc_topics")

        self.dec_str_inps = tf.placeholder(
            tf.string, shape=[None, conf.output_max_len + 2], name="dec_inps")
        self.dec_lens = tf.placeholder(tf.int32, shape=[None], name="dec_lens")

        # Table lookup: token strings -> ids.
        self.enc_inps = self.in_table.lookup(self.enc_str_inps)
        self.enc_topics = self.topic_in_table.lookup(self.enc_str_topics)
        self.dec_inps = self.in_table.lookup(self.dec_str_inps)

        batch_size = tf.shape(self.enc_inps)[0]

        with variable_scope.variable_scope(self.model_kind, dtype=dtype):
            # Create encode graph and get attn states
            graphlg.info("Creating embeddings and do lookup...")
            t_major_enc_inps = tf.transpose(self.enc_inps)
            # Embedding tables and their lookups are pinned to the CPU.
            with ops.device("/cpu:0"):
                self.embedding = variable_scope.get_variable(
                    "embedding", [conf.input_vocab_size, conf.embedding_size])
                self.emb_enc_inps = embedding_lookup_unique(
                    self.embedding, t_major_enc_inps)
                self.topic_embedding = variable_scope.get_variable(
                    "topic_embedding",
                    [conf.topic_vocab_size, conf.topic_embedding_size],
                    trainable=False)
                self.emb_enc_topics = embedding_lookup_unique(
                    self.topic_embedding, self.enc_topics)

            graphlg.info("Creating out projection weights...")
            # The projection's input width is out_layer_size when it is
            # configured, otherwise the RNN's num_units.
            if conf.out_layer_size is not None:
                w = tf.get_variable(
                    "proj_w", [conf.out_layer_size, conf.output_vocab_size],
                    dtype=dtype)
            else:
                w = tf.get_variable("proj_w",
                                    [conf.num_units, conf.output_vocab_size],
                                    dtype=dtype)
            b = tf.get_variable("proj_b", [conf.output_vocab_size],
                                dtype=dtype)
            self.out_proj = (w, b)

            graphlg.info("Creating encoding dynamic rnn...")
            with variable_scope.variable_scope("encoder",
                                               dtype=dtype) as scope:
                if conf.bidirectional:
                    cell_fw = CreateMultiRNNCell(conf.cell_model,
                                                 conf.num_units,
                                                 conf.num_layers,
                                                 conf.output_keep_prob)
                    cell_bw = CreateMultiRNNCell(conf.cell_model,
                                                 conf.num_units,
                                                 conf.num_layers,
                                                 conf.output_keep_prob)
                    self.enc_outs, self.enc_states = bidirectional_dynamic_rnn(
                        cell_fw=cell_fw,
                        cell_bw=cell_bw,
                        inputs=self.emb_enc_inps,
                        sequence_length=self.enc_lens,
                        dtype=dtype,
                        parallel_iterations=16,
                        time_major=True,
                        scope=scope)
                    # Concatenate forward/backward states layer by layer and
                    # the two output streams along the feature axis.
                    # NOTE(review): this widens the states relative to the
                    # num_units-wide decoder cell built below; confirm that
                    # CreateMultiRNNCell / AttnTopicDecoder accept it.
                    fw_s, bw_s = self.enc_states
                    self.enc_states = tuple([
                        tf.concat([f, b], axis=1) for f, b in zip(fw_s, bw_s)
                    ])
                    self.enc_outs = tf.concat(
                        [self.enc_outs[0], self.enc_outs[1]], axis=2)
                else:
                    cell = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                              conf.num_layers,
                                              conf.output_keep_prob)
                    self.enc_outs, self.enc_states = dynamic_rnn(
                        cell=cell,
                        inputs=self.emb_enc_inps,
                        sequence_length=self.enc_lens,
                        parallel_iterations=16,
                        scope=scope,
                        dtype=dtype,
                        time_major=True)

            graphlg.info("Preparing init attention and states for decoder...")
            initial_state = self.enc_states
            # The encoder ran time-major; attention states are batch-major.
            attn_states = tf.transpose(self.enc_outs, perm=[1, 0, 2])
            attn_size = self.conf.num_units
            topic_attn_size = self.conf.num_units
            # Project every topic embedding to topic_attn_size with a 1x1
            # convolution (a position-wise linear map).
            k = tf.get_variable(
                "topic_proj",
                [1, 1, self.conf.topic_embedding_size, topic_attn_size])
            topic_attn_states = nn_ops.conv2d(
                tf.expand_dims(self.emb_enc_topics, 2), k, [1, 1, 1, 1],
                "SAME")
            topic_attn_states = tf.squeeze(topic_attn_states, axis=2)

            graphlg.info("Creating decoder cell...")
            with variable_scope.variable_scope("decoder",
                                               dtype=dtype) as scope:
                cell = CreateMultiRNNCell(conf.cell_model, attn_size,
                                          conf.num_layers,
                                          conf.output_keep_prob)
                # topic
                if not self.for_deploy:
                    graphlg.info(
                        "Embedding decoder inps, tars and tar weights...")
                    t_major_dec_inps = tf.transpose(self.dec_inps)
                    # Targets are the decoder inputs shifted by one step.
                    t_major_tars = tf.slice(t_major_dec_inps, [1, 0],
                                            [conf.output_max_len + 1, -1])
                    t_major_dec_inps = tf.slice(t_major_dec_inps, [0, 0],
                                                [conf.output_max_len + 1, -1])
                    # Weight 1.0 on the first dec_lens steps of each example
                    # and 0.0 afterwards (time-major).
                    t_major_tar_wgts = tf.cumsum(
                        tf.one_hot(self.dec_lens - 1,
                                   conf.output_max_len + 1,
                                   axis=0),
                        axis=0,
                        reverse=True)
                    with ops.device("/cpu:0"):
                        emb_dec_inps = embedding_lookup_unique(
                            self.embedding, t_major_dec_inps)

                    hp_train = helper.ScheduledEmbeddingTrainingHelper(
                        inputs=emb_dec_inps,
                        # The helper steps through the decoder inputs, so it
                        # needs the decoder lengths (the same ones that build
                        # t_major_tar_wgts above), not the encoder lengths.
                        sequence_length=self.dec_lens,
                        embedding=self.embedding,
                        sampling_probability=0.0,
                        out_proj=self.out_proj,
                        except_ids=None,
                        time_major=True)

                    output_layer = None
                    my_decoder = AttnTopicDecoder(
                        cell=cell,
                        helper=hp_train,
                        initial_state=initial_state,
                        attn_states=attn_states,
                        attn_size=attn_size,
                        topic_attn_states=topic_attn_states,
                        topic_attn_size=topic_attn_size,
                        output_layer=output_layer)
                    t_major_cell_outs, final_state = decoder.dynamic_decode(
                        decoder=my_decoder,
                        output_time_major=True,
                        maximum_iterations=conf.output_max_len + 1,
                        scope=scope)
                    t_major_outs = t_major_cell_outs.rnn_output

                    # Branch 1 for debugging, doesn't have to be called:
                    # project the outputs, take the argmax id per step and map
                    # the ids back to token strings.
                    self.outputs = tf.transpose(t_major_outs, perm=[1, 0, 2])
                    dec_len = tf.shape(self.outputs)[1]
                    w, b = self.out_proj
                    self.outputs = tf.reshape(self.outputs,
                                              [-1, int(w.shape[0])])
                    self.outputs = tf.matmul(self.outputs, w) + b
                    self.outputs = tf.argmax(self.outputs, axis=1)
                    self.outputs = tf.reshape(self.outputs, [-1, dec_len])
                    self.outputs = self.out_table.lookup(
                        tf.cast(self.outputs, tf.int64))

                    # Branch 2 for loss
                    self.loss = dyn_sequence_loss(self.conf, t_major_outs,
                                                  self.out_proj, t_major_tars,
                                                  t_major_tar_wgts)
                    self.summary = tf.summary.scalar("%s/loss" % self.name,
                                                     self.loss)

                    # backpropagation
                    self.build_backprop(self.loss, conf, dtype)

                    # Saver: topic_embedding is not trainable, so it is added
                    # explicitly next to the trainable variables.
                    self.trainable_params.extend(tf.trainable_variables() +
                                                 [self.topic_embedding])
                    need_to_save = (
                        self.global_params + self.trainable_params +
                        self.optimizer_params +
                        tf.get_default_graph().get_collection(
                            "saveable_objects") + [self.topic_embedding])
                    self.saver = tf.train.Saver(need_to_save,
                                                max_to_keep=conf.max_to_keep)
                else:
                    hp_infer = helper.GreedyEmbeddingHelper(
                        embedding=self.embedding,
                        start_tokens=tf.ones(shape=[batch_size],
                                             dtype=tf.int32),
                        end_token=EOS_ID,
                        out_proj=self.out_proj)

                    output_layer = None
                    my_decoder = AttnTopicDecoder(
                        cell=cell,
                        helper=hp_infer,
                        initial_state=initial_state,
                        attn_states=attn_states,
                        attn_size=attn_size,
                        topic_attn_states=topic_attn_states,
                        topic_attn_size=topic_attn_size,
                        output_layer=output_layer)
                    # NOTE(review): the decode length is hard-coded to 40
                    # here rather than derived from conf.output_max_len.
                    cell_outs, final_state = decoder.dynamic_decode(
                        decoder=my_decoder, scope=scope, maximum_iterations=40)
                    self.outputs = cell_outs.sample_id
                    # Lookup: ids -> token strings.
                    self.outputs = self.out_table.lookup(
                        tf.cast(self.outputs, tf.int64))

                    # Saver
                    self.trainable_params.extend(tf.trainable_variables())
                    self.saver = tf.train.Saver(max_to_keep=conf.max_to_keep)

                    # Exporter for serving. The decoder attends over
                    # topic_attn_states, which is computed from the
                    # enc_str_topics placeholder, so clients must be able to
                    # feed it through the signature as well.
                    self.model_exporter = exporter.Exporter(self.saver)
                    inputs = {
                        "enc_inps": self.enc_str_inps,
                        "enc_lens": self.enc_lens,
                        "enc_topics": self.enc_str_topics,
                    }
                    outputs = {"out": self.outputs}
                    self.model_exporter.init(
                        tf.get_default_graph().as_graph_def(),
                        named_graph_signatures={
                            "inputs": exporter.generic_signature(inputs),
                            "outputs": exporter.generic_signature(outputs)
                        })
                    graphlg.info("Graph done")
                    graphlg.info("")

                self.dec_states = final_state
    def inference(self, inputs, train=True):
        """Build the text-to-spectrogram graph and return its outputs.

        Pipeline: character embedding -> optional speaker embedding ->
        pre-net + CBHG encoder -> attention decoder (``self.create_decoder``)
        -> second CBHG post-processing network -> dense projection to
        ``config.fft_size`` features.

        Args:
            inputs: dict of input tensors. ``inputs['text']`` is looked up in
                the character embedding; ``inputs['speaker']`` is read only
                when ``config.num_speakers > 1``. The dict is also passed on
                to ``self.create_decoder``.
            train: forwarded to ``self.pre_net`` and ``self.create_decoder``.

        Returns:
            ``(seq2seq_output, output)``: the decoder's RNN output and the
            post-processed output, reshaped so its last axis has
            ``config.fft_size * config.r`` values.

        Side effects:
            Sets ``self.alignments`` to the stacked attention alignment
            history with its first two axes swapped, records histogram
            summaries and prints debug information.
        """
        config = self.config

        # extract character representations from embedding
        with tf.variable_scope(
                'embedding',
                initializer=tf.contrib.layers.xavier_initializer()):
            print("\n\nembedding: shape: %s" % str(
                (config.vocab_size, config.embed_dim)))
            embedding = tf.get_variable('embedding',
                                        shape=(config.vocab_size,
                                               config.embed_dim),
                                        dtype=tf.float32)
            embedded_inputs = tf.nn.embedding_lookup(embedding, inputs['text'])

        # extract speaker embedding if multi-speaker
        with tf.variable_scope('speaker'):
            if config.num_speakers > 1:
                speaker_embed = tf.get_variable(
                    'speaker_embed',
                    shape=(config.num_speakers, config.speaker_embed_dim),
                    dtype=tf.float32)
                speaker_embed = tf.nn.embedding_lookup(speaker_embed,
                                                       inputs['speaker'])
            else:
                speaker_embed = None

        # process text input with CBHG module
        with tf.variable_scope('encoder'):
            print("\n\nencoder: inputs: %s" % str(inputs))
            pre_out = self.pre_net(embedded_inputs,
                                   dropout=config.char_dropout_prob,
                                   train=train)
            tf.summary.histogram('pre_net_out', pre_out)

            encoded = ops.CBHG(pre_out,
                               speaker_embed,
                               K=16,
                               c=[128, 128, 128],
                               gru_units=128)

        # pass through attention based decoder
        with tf.variable_scope('decoder'):
            print("\n\ndecoder: inputs: %s\n\n" % str(inputs))
            dec = self.create_decoder(encoded, inputs, speaker_embed, train)
            (seq2seq_output, _), attention_state, _ = decoder.dynamic_decode(
                dec, maximum_iterations=config.max_decode_iter)
            # Swap the first two axes of the stacked alignment history.
            self.alignments = tf.transpose(
                attention_state.alignment_history.stack(), [1, 0, 2])
            print("seq2seq_output: %s" % str(seq2seq_output))
            tf.summary.histogram('seq2seq_output', seq2seq_output)

        # use second CBHG module to process mel features into linear spectrogram
        with tf.variable_scope('post-process'):
            # reshape to account for r value
            post_input = tf.reshape(
                seq2seq_output,
                (tf.shape(seq2seq_output)[0], -1, config.mel_features))

            output = ops.CBHG(post_input, K=8, c=[128, 256, 80], gru_units=128)
            output = tf.layers.dense(output, units=config.fft_size)

            # reshape back to r frame representation
            output = tf.reshape(
                output, (tf.shape(output)[0], -1, config.fft_size * config.r))
            tf.summary.histogram('output', output)

        return seq2seq_output, output
  def _testWithAttention(self,
                         create_attention_mechanism,
                         expected_final_outputs,
                         expected_final_state,
                         attention_mechanism_depth=3):
    """Decode a fixed random batch through DynamicAttentionWrapper and check it.

    An LSTM cell is wrapped in ``wrapper.DynamicAttentionWrapper`` around the
    mechanism returned by ``create_attention_mechanism`` and decoded with a
    ``TrainingHelper`` + ``BasicDecoder``. The static types and shapes of the
    results are checked, then the graph is run and the evaluated outputs and
    final state are compared against the expected structures with
    ``assertAllClose``.

    Args:
      create_attention_mechanism: callable accepting ``num_units``,
        ``memory`` and ``memory_sequence_length`` keyword arguments.
      expected_final_outputs: structure compared with the evaluated
        ``final_outputs``.
      expected_final_state: structure compared with the evaluated
        ``final_state``.
      attention_mechanism_depth: ``num_units`` of the attention mechanism.
    """
    batch = 5
    enc_lengths = [3, 2, 3, 1, 0]
    dec_lengths = [2, 0, 1, 2, 3]
    enc_time, dec_time = 8, 4
    in_depth, enc_depth = 7, 10
    lstm_depth, attn_depth = 9, 6

    # Keep this draw order (decoder inputs first): the expected values
    # depend on the exact random data produced.
    dec_inputs = np.random.randn(batch, dec_time, in_depth).astype(np.float32)
    enc_outputs = np.random.randn(batch, enc_time,
                                  enc_depth).astype(np.float32)

    mechanism = create_attention_mechanism(
        num_units=attention_mechanism_depth,
        memory=enc_outputs,
        memory_sequence_length=enc_lengths)

    with self.test_session() as sess:
      init = init_ops.random_normal_initializer(stddev=0.01, seed=3)
      with vs.variable_scope("root", initializer=init):
        attn_cell = wrapper.DynamicAttentionWrapper(
            core_rnn_cell.LSTMCell(lstm_depth),
            mechanism,
            attention_size=attn_depth)
        train_helper = helper_py.TrainingHelper(dec_inputs, dec_lengths)
        start_state = attn_cell.zero_state(dtype=dtypes.float32,
                                           batch_size=batch)
        basic = basic_decoder.BasicDecoder(
            cell=attn_cell, helper=train_helper, initial_state=start_state)
        final_outputs, final_state = decoder.dynamic_decode(basic)

      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.DynamicAttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))

      # (expected static shape, tensor) pairs; None marks the dynamic time
      # dimension of the decoder outputs.
      static_shape_checks = (
          ((batch, None, attn_depth), final_outputs.rnn_output),
          ((batch, None), final_outputs.sample_id),
          ((batch, attn_depth), final_state.attention),
          ((batch, lstm_depth), final_state.cell_state.c),
          ((batch, lstm_depth), final_state.cell_state.h),
      )
      for want_shape, tensor in static_shape_checks:
        self.assertEqual(want_shape, tuple(tensor.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      fetched = sess.run({
          "final_outputs": final_outputs,
          "final_state": final_state
      })

      for key, expected in (("final_outputs", expected_final_outputs),
                            ("final_state", expected_final_state)):
        nest.map_structure(self.assertAllClose, expected, fetched[key])
Ejemplo n.º 23
0
def impl(features, mode, hp):
    """Build the hierarchical-attention context -> response model graph.

    Every context utterance is encoded with a shared bidirectional GRU. A
    ``ContextAttentionMechanism`` over those encodings drives an
    ``AttentionWrapper``-ed GRU decoder with a bias-free vocabulary
    projection.

    Args:
        features: dict of input tensors. Always reads ``'contexts'``,
            ``'context_utterance_length'`` and ``'context_length'``; in
            TRAIN/EVAL also reads ``'response_in'``, ``'response_out'`` and
            ``'response_mask'``.
        mode: one of ``modekeys.TRAIN``, ``modekeys.EVAL``,
            ``modekeys.PREDICT``.
        hp: hyper-parameter object (attributes read below).

    Returns:
        TRAIN: ``(loss, debug_tensors)``. The loss is the masked token NLL
            summed over time and averaged over the batch, plus
            ``hp.lambda_l2`` times the L2 norm of every non-bias trainable
            variable.
        EVAL: perplexity, i.e. ``exp`` of the mean NLL over unmasked tokens.
        PREDICT: dict with ``'response_ids'`` and ``'response_lens'``
            (``'response_lens'`` is ``None`` for beam search).
        Any other mode: ``None``.
    """
    # batch_size, max_con_length (with query), max_sen_length
    contexts = features['contexts']
    # batch_size, max_con_length
    context_utterance_length = features['context_utterance_length']
    # batch_size
    context_length = features['context_length']
    if mode in (modekeys.TRAIN, modekeys.EVAL):
        response_in = features['response_in']  # batch, max_res_sen
        response_out = features['response_out']  # batch, max_res_sen
        # batch, max_res_sen; tf.float32
        response_mask = features['response_mask']
        batch_size = hp.batch_size
    else:
        batch_size = context_utterance_length.shape[0].value

    with tf.variable_scope('embedding_layer', reuse=tf.AUTO_REUSE):
        embedding_w = get_embedding_matrix(hp.word_dim, mode, hp.vocab_size,
                                           random_seed, hp.word_embed_path,
                                           hp.vocab_path)
        contexts = tf.nn.embedding_lookup(embedding_w, contexts,
                                          'context_embedding')
        if mode in (modekeys.TRAIN, modekeys.EVAL):
            response_in = tf.nn.embedding_lookup(embedding_w, response_in,
                                                 'response_in_embedding')

    with tf.variable_scope('utterance_encoding_layer', reuse=tf.AUTO_REUSE):
        # Forward and backward cells use different initializer seeds.
        fw_cell = tf.nn.rnn_cell.GRUCell(
            num_units=hp.word_rnn_num_units,
            kernel_initializer=tf.random_normal_initializer(
                mean=0.0, stddev=0.1, seed=random_seed + 1),
            bias_initializer=tf.zeros_initializer())
        bw_cell = tf.nn.rnn_cell.GRUCell(
            num_units=hp.word_rnn_num_units,
            kernel_initializer=tf.random_normal_initializer(
                mean=0.0, stddev=0.1, seed=random_seed - 1),
            bias_initializer=tf.zeros_initializer())

        # max_con_length (with query), batch_size, max_sen_length, (embedding)
        context_t = tf.transpose(contexts, perm=[1, 0, 2, 3])
        # max_con_length, batch_size
        context_utterance_length_t = tf.transpose(context_utterance_length,
                                                  perm=[1, 0])
        # Each slice: 1, batch_size, max_sen_length, (embedding)
        utterance_slices = tf.split(context_t, hp.max_context_length, axis=0)
        # Each slice: 1, batch_size
        length_slices = tf.split(context_utterance_length_t,
                                 hp.max_context_length,
                                 axis=0)

        # The same fw/bw cell objects encode every utterance.
        utterance_encodings = []
        for utterance, length in zip(utterance_slices, length_slices):
            utterance = tf.squeeze(utterance, axis=0)
            length = tf.squeeze(length, axis=0)
            utterance_hidden_states, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell,
                bw_cell,
                utterance,
                sequence_length=length,
                initial_state_fw=fw_cell.zero_state(batch_size, tf.float32),
                initial_state_bw=bw_cell.zero_state(batch_size, tf.float32))
            utterance_encoding = tf.concat(utterance_hidden_states, axis=2)
            utterance_encodings.append(
                tf.expand_dims(utterance_encoding, axis=0))

        # max_con_length, batch_size, max_sen, 2 * word_rnn_num_units
        utterance_encodings = tf.concat(utterance_encodings, axis=0)

    with tf.variable_scope('hierarchical_attention_layer',
                           reuse=tf.AUTO_REUSE):
        if mode == modekeys.PREDICT and hp.beam_width != 0:
            # Beam search needs every batch entry repeated beam_width times;
            # tile along the batch axis (moved to the front temporarily).
            utterance_encodings = tf.transpose(utterance_encodings,
                                               perm=[1, 0, 2, 3])
            utterance_encodings = tile_batch(utterance_encodings,
                                             multiplier=hp.beam_width)
            utterance_encodings = tf.transpose(utterance_encodings,
                                               perm=[1, 0, 2, 3])

            context_utterance_length_t = tf.transpose(
                context_utterance_length_t, perm=[1, 0])
            context_utterance_length_t = tile_batch(context_utterance_length_t,
                                                    multiplier=hp.beam_width)
            context_utterance_length_t = tf.transpose(
                context_utterance_length_t, perm=[1, 0])

            context_length = tile_batch(context_length,
                                        multiplier=hp.beam_width)

        attention_mechanism = ContextAttentionMechanism(
            context_attn_units=hp.context_attn_units,
            utte_attn_units=hp.utte_attn_units,
            context=utterance_encodings,
            context_utterance_length=context_utterance_length_t,
            max_context_length=hp.max_context_length,
            context_rnn_num_units=hp.context_rnn_num_units,
            context_actual_length=context_length)

    with tf.variable_scope('decoder_layer', reuse=tf.AUTO_REUSE):
        kernel_initializer = tf.random_normal_initializer(mean=0.0,
                                                          stddev=0.1,
                                                          seed=random_seed + 3)
        bias_initializer = tf.zeros_initializer()
        decoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=hp.decoder_rnn_num_units,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer)
        attn_cell = AttentionWrapper(
            decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=None,
            output_attention=False)  # output_attention should be False
        output_layer = layers_core.Dense(
            units=hp.vocab_size, activation=None,
            use_bias=False)  # should use no activation and no bias

        if mode == modekeys.TRAIN:
            sequence_length = tf.constant(value=hp.max_sentence_length,
                                          dtype=tf.int32,
                                          shape=[batch_size])
            helper = TrainingHelper(inputs=response_in,
                                    sequence_length=sequence_length)
            decoder = BasicDecoder(cell=attn_cell,
                                   helper=helper,
                                   initial_state=attn_cell.zero_state(
                                       batch_size, tf.float32),
                                   output_layer=output_layer)
            final_outputs, _, _ = dynamic_decode(decoder=decoder,
                                                 impute_finished=True,
                                                 parallel_iterations=32,
                                                 swap_memory=True)
            logits = final_outputs.rnn_output
            # Masked token NLL, summed over time, averaged over the batch.
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=response_out, logits=logits)
            cross_entropy = tf.multiply(cross_entropy, response_mask)
            cross_entropy = tf.reduce_sum(cross_entropy, axis=1)
            loss = tf.reduce_mean(cross_entropy)
            # L2 penalty on every trainable variable except biases.
            l2_norm = hp.lambda_l2 * tf.add_n([
                tf.nn.l2_loss(var)
                for var in tf.trainable_variables() if 'bias' not in var.name
            ])
            loss = loss + l2_norm

            debug_tensors = []
            return loss, debug_tensors
        elif mode == modekeys.EVAL:
            sequence_length = tf.constant(value=hp.max_sentence_length,
                                          dtype=tf.int32,
                                          shape=[batch_size])
            # NOTE(review): TRAIN uses the bare TrainingHelper name; confirm
            # both refer to the same class.
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=response_in, sequence_length=sequence_length)
            decoder = BasicDecoder(cell=attn_cell,
                                   helper=helper,
                                   initial_state=attn_cell.zero_state(
                                       batch_size, tf.float32),
                                   output_layer=output_layer)
            final_outputs, _, _ = dynamic_decode(decoder=decoder,
                                                 impute_finished=True,
                                                 parallel_iterations=32,
                                                 swap_memory=True)
            logits = final_outputs.rnn_output
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=response_out, logits=logits)
            # Average the NLL over real (unmasked) tokens only. A plain
            # reduce_mean would also count padded positions as zero loss and
            # under-report perplexity. The max(..., 1) guards an all-masked
            # batch against division by zero.
            masked_nll = tf.reduce_sum(cross_entropy * response_mask)
            token_count = tf.maximum(tf.reduce_sum(response_mask), 1.0)
            ppl = tf.exp(masked_nll / token_count)
            return ppl
        elif mode == modekeys.PREDICT:
            # NOTE(review): start token 1 / end token 2 are presumably the
            # GO / EOS ids of the vocabulary -- confirm.
            if hp.beam_width == 0:
                helper = GreedyEmbeddingHelper(embedding=embedding_w,
                                               start_tokens=tf.constant(
                                                   1,
                                                   tf.int32,
                                                   shape=[batch_size]),
                                               end_token=2)
                initial_state = attn_cell.zero_state(batch_size=batch_size,
                                                     dtype=tf.float32)
                decoder = BasicDecoder(cell=attn_cell,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=output_layer)
                final_outputs, _, final_sequence_lengths = dynamic_decode(
                    decoder, maximum_iterations=hp.max_sentence_length)
                results = {}
                results['response_ids'] = final_outputs.sample_id
                results['response_lens'] = final_sequence_lengths
                return results
            else:
                decoder_initial_state = attn_cell.zero_state(
                    batch_size=batch_size * hp.beam_width, dtype=tf.float32)
                decoder = BeamSearchDecoder(
                    cell=attn_cell,
                    embedding=embedding_w,
                    start_tokens=tf.constant(1, tf.int32, shape=[batch_size]),
                    end_token=2,
                    initial_state=decoder_initial_state,
                    beam_width=hp.beam_width,
                    output_layer=output_layer)
                final_outputs, _, _ = dynamic_decode(
                    decoder,
                    impute_finished=False,
                    maximum_iterations=hp.max_sentence_length)

                final_outputs = final_outputs.predicted_ids  # b,s,beam_width
                final_outputs = tf.transpose(final_outputs,
                                             perm=[0, 2, 1])  # b,beam_width,s
                # Lengths are not reported for beam search (a candidate
                # would be final_state.lengths).

                results = {}
                results['response_ids'] = final_outputs
                results['response_lens'] = None
                return results
Ejemplo n.º 24
0
  def _testWithAttention(self,
                         create_attention_mechanism,
                         expected_final_output,
                         expected_final_state,
                         attention_mechanism_depth=3,
                         alignment_history=False,
                         expected_final_alignment_history=None,
                         attention_layer_size=6,
                         name=""):
    """Decodes random inputs through an `AttentionWrapper` and checks results.

    Wraps an LSTM cell in `wrapper.AttentionWrapper`, decodes with a
    `TrainingHelper`/`BasicDecoder` via `decoder.dynamic_decode`, checks the
    static types and shapes of the outputs and final state, then evaluates the
    graph and compares the values against the expected ones with
    `assertAllClose`.

    Args:
      create_attention_mechanism: Callable invoked with `num_units`, `memory`
        and `memory_sequence_length` keyword arguments; its return value is
        passed to `AttentionWrapper` as the attention mechanism.
      expected_final_output: Expected evaluated `final_outputs`, compared
        structure-wise with `nest.map_structure(self.assertAllClose, ...)`.
      expected_final_state: Expected evaluated `final_state`. When
        `alignment_history` is True, the state's `alignment_history` field is
        replaced by `()` before evaluation.
      attention_mechanism_depth: `num_units` passed to
        `create_attention_mechanism`.
      alignment_history: Passed to `AttentionWrapper`. When True, the stacked
        history is shape-checked and compared separately.
      expected_final_alignment_history: Expected stacked alignment history.
        It is only compared when `alignment_history` is True.
      attention_layer_size: Passed to `AttentionWrapper`. When None, the
        expected attention depth is the encoder output depth instead.
      name: Label included in the "Copy/paste" debug prints.
    """
    encoder_sequence_length = [3, 2, 3, 1, 0]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    # Expected depth of the attention vector: attention_layer_size when set,
    # otherwise the encoder output depth.
    if attention_layer_size is not None:
      attention_depth = attention_layer_size
    else:
      attention_depth = encoder_output_depth

    # NOTE(review): no seed is set in this method. The hard-coded expected
    # values presumably rely on a seed set elsewhere (e.g. the test harness)
    # and on these two draws happening in this order. Confirm before reordering.
    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    attention_mechanism = create_attention_mechanism(
        num_units=attention_mechanism_depth,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length)

    with self.test_session(use_gpu=True) as sess:
      # Default initializer, with a fixed op seed, for variables created
      # under this scope.
      with vs.variable_scope(
          "root",
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        cell = core_rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        # The third return value (sequence lengths) is not used here.
        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      # Static type checks on the decoder results, before running the graph.
      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))

      # The time dimension is statically unknown (None). The rnn_output depth
      # is expected to equal attention_depth, i.e. the wrapper's output is the
      # attention vector.
      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        state_alignment_history = final_state.alignment_history.stack()
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
        # Stacked history is time major: (time, batch_size, encoder_max_time).
        self.assertEqual((None, batch_size, encoder_max_time),
                         tuple(state_alignment_history.get_shape().as_list()))
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "final_outputs": final_outputs,
          "final_state": final_state,
          "state_alignment_history": state_alignment_history,
      })

      # Debug output: prints the actual values, labelled "Copy/paste", so they
      # can be pasted back into the caller as new expected values.
      print("Copy/paste (%s)\nexpected_final_output = " % name,
            sess_results["final_outputs"])
      sys.stdout.flush()
      print("Copy/paste (%s)\nexpected_final_state = " % name,
            sess_results["final_state"])
      sys.stdout.flush()
      print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
            np.asarray(sess_results["state_alignment_history"]))
      sys.stdout.flush()
      nest.map_structure(self.assertAllClose, expected_final_output,
                         sess_results["final_outputs"])
      nest.map_structure(self.assertAllClose, expected_final_state,
                         sess_results["final_state"])
      if alignment_history:  # the history is only recorded in this case
        self.assertAllClose(
            # the stacked TensorArray is time major, unlike the batch-major outputs
            sess_results["state_alignment_history"],
            expected_final_alignment_history)
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_outputs,
                           expected_final_state,
                           attention_mechanism_depth=3):
        """Decodes random inputs through a `DynamicAttentionWrapper`.

        Wraps an LSTM cell with `wrapper.DynamicAttentionWrapper`, decodes with
        a `TrainingHelper`/`BasicDecoder` via `decoder.dynamic_decode`, checks
        the static types and shapes of the results, then evaluates the graph
        and compares the values against the expected ones with
        `assertAllClose`.

        Args:
          create_attention_mechanism: Callable invoked with `num_units`,
            `memory` and `memory_sequence_length` keyword arguments. Its return
            value is the attention mechanism passed to the wrapper.
          expected_final_outputs: Expected evaluated `final_outputs`, compared
            structure-wise via `nest.map_structure`.
          expected_final_state: Expected evaluated `final_state`, compared
            structure-wise via `nest.map_structure`.
          attention_mechanism_depth: `num_units` passed to
            `create_attention_mechanism`.
        """
        encoder_sequence_length = [3, 2, 3, 1, 0]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        attention_depth = 6

        # NOTE(review): no seed is set in this method. The expected values
        # presumably depend on a seed set elsewhere and on this draw order.
        # Confirm before reordering.
        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session() as sess:
            # Default initializer, with a fixed op seed, for variables created
            # under this scope.
            with vs.variable_scope(
                    "root",
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = core_rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.DynamicAttentionWrapper(
                    cell, attention_mechanism, attention_size=attention_depth)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                # NOTE(review): this unpacks two values, while the other decode
                # tests in this file unpack three (outputs, state, sequence
                # lengths). Confirm which dynamic_decode version this targets.
                final_outputs, final_state = decoder.dynamic_decode(my_decoder)

            # Static type checks, done before running the graph.
            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.DynamicAttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state,
                           core_rnn_cell.LSTMStateTuple))

            # The time dimension is statically unknown (None).
            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs": final_outputs,
                "final_state": final_state
            })

            nest.map_structure(self.assertAllClose, expected_final_outputs,
                               sess_results["final_outputs"])
            nest.map_structure(self.assertAllClose, expected_final_state,
                               sess_results["final_state"])
    def _testDynamicDecodeRNN(self, time_major, has_attention):
        """Runs a `BeamSearchDecoder` through `dynamic_decode`; checks shapes.

        Args:
          time_major: Forwarded to `decoder.dynamic_decode` as
            `output_time_major`. Expected shapes have their first two
            dimensions swapped when it is True.
          has_attention: When True, the LSTM cell is wrapped in an
            `AttentionWrapper` over a `BahdanauAttention` mechanism. The memory,
            its sequence lengths and the LSTM initial state are tiled
            `beam_width` times.
        """
        enc_lengths = np.array([3, 2, 3, 1, 1])
        dec_lengths = np.array([2, 0, 1, 2, 3])
        batch_size, decoder_max_time, input_depth = 5, 4, 7
        cell_depth, attention_depth = 9, 6
        vocab_size, embedding_dim = 20, 50
        start_token, end_token = 0, vocab_size - 1
        max_iterations = max(dec_lengths)
        output_layer = layers_core.Dense(vocab_size,
                                         use_bias=True,
                                         activation=None)
        beam_width = 3

        def maybe_time_major(shape):
            # Swap the batch and time dimensions for time-major outputs.
            if not time_major:
                return shape
            return (shape[1], shape[0]) + shape[2:]

        with self.test_session() as sess:
            batch_tensor = constant_op.constant(batch_size)
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            decoder_cell = rnn_cell.LSTMCell(cell_depth)
            lstm_state = decoder_cell.zero_state(batch_size, dtypes.float32)

            if has_attention:
                memory = array_ops.placeholder_with_default(
                    np.random.randn(batch_size, decoder_max_time,
                                    input_depth).astype(np.float32),
                    shape=(None, None, input_depth))
                # Arguments are evaluated left to right, so the memory is tiled
                # before the sequence lengths.
                mechanism = attention_wrapper.BahdanauAttention(
                    num_units=attention_depth,
                    memory=beam_search_decoder.tile_batch(
                        memory, multiplier=beam_width),
                    memory_sequence_length=beam_search_decoder.tile_batch(
                        enc_lengths, multiplier=beam_width))
                lstm_state = beam_search_decoder.tile_batch(
                    lstm_state, multiplier=beam_width)
                decoder_cell = attention_wrapper.AttentionWrapper(
                    cell=decoder_cell,
                    attention_mechanism=mechanism,
                    attention_layer_size=attention_depth,
                    alignment_history=False)

            # One state per beam, i.e. batch_size * beam_width rows.
            start_state = decoder_cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_tensor * beam_width)
            if has_attention:
                start_state = start_state.clone(cell_state=lstm_state)

            beam_decoder = beam_search_decoder.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=embedding,
                start_tokens=array_ops.fill([batch_tensor], start_token),
                end_token=end_token,
                initial_state=start_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0)

            outputs, state, lengths = decoder.dynamic_decode(
                beam_decoder,
                output_time_major=time_major,
                maximum_iterations=max_iterations)

            self.assertTrue(
                isinstance(outputs,
                           beam_search_decoder.FinalBeamSearchDecoderOutput))
            self.assertTrue(
                isinstance(state, beam_search_decoder.BeamSearchDecoderState))

            # Static shapes: the time dimension is unknown before running.
            static_shape = maybe_time_major((batch_size, None, beam_width))
            step_outputs = outputs.beam_search_decoder_output
            self.assertEqual(
                static_shape,
                tuple(step_outputs.scores.get_shape().as_list()))
            self.assertEqual(
                static_shape,
                tuple(outputs.predicted_ids.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            evaluated = sess.run({
                'final_outputs': outputs,
                'final_state': state,
                'final_sequence_lengths': lengths,
            })

            # Smoke test: the evaluated time dimension equals the largest
            # returned sequence length.
            longest = np.max(evaluated['final_sequence_lengths'])
            dynamic_shape = maybe_time_major((batch_size, longest, beam_width))
            evaluated_steps = (
                evaluated['final_outputs'].beam_search_decoder_output)
            self.assertEqual(dynamic_shape, evaluated_steps.scores.shape)
            self.assertEqual(dynamic_shape, evaluated_steps.predicted_ids.shape)
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_output,
                           expected_final_state,
                           attention_mechanism_depth=3,
                           alignment_history=False,
                           expected_final_alignment_history=None,
                           name=""):
        """Decodes random inputs through an `AttentionWrapper` and checks results.

        Wraps an LSTM cell in `wrapper.AttentionWrapper` (with a fixed
        `attention_size` of 6), decodes with a `TrainingHelper`/`BasicDecoder`
        via `decoder.dynamic_decode`, checks static types and shapes, then
        evaluates the graph and compares the values against the expected ones
        with `assertAllClose`.

        Args:
          create_attention_mechanism: Callable invoked with `num_units`,
            `memory` and `memory_sequence_length` keyword arguments. Its return
            value is the attention mechanism passed to the wrapper.
          expected_final_output: Expected evaluated `final_outputs`.
          expected_final_state: Expected evaluated `final_state`. When
            `alignment_history` is True, its `alignment_history` field is
            replaced by `()` before evaluation.
          attention_mechanism_depth: `num_units` passed to
            `create_attention_mechanism`.
          alignment_history: Passed to `AttentionWrapper`. When True, the
            stacked history is shape-checked and compared separately.
          expected_final_alignment_history: Expected stacked alignment
            history. It is only compared when `alignment_history` is True.
          name: Label included in the "Copy/paste" debug prints.
        """
        encoder_sequence_length = [3, 2, 3, 1, 0]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        attention_depth = 6

        # NOTE(review): no seed is set in this method. The expected values
        # presumably depend on a seed set elsewhere and on this draw order.
        # Confirm before reordering.
        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session(use_gpu=True) as sess:
            # Default initializer, with a fixed op seed, for variables created
            # under this scope.
            with vs.variable_scope(
                    "root",
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = core_rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_size=attention_depth,
                    alignment_history=alignment_history)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                # NOTE(review): this unpacks two values, while the other decode
                # tests in this file unpack three (outputs, state, sequence
                # lengths). Confirm which dynamic_decode version this targets.
                final_outputs, final_state = decoder.dynamic_decode(my_decoder)

            # Static type checks, done before running the graph.
            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.AttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state,
                           core_rnn_cell.LSTMStateTuple))

            # The time dimension is statically unknown (None).
            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            if alignment_history:
                state_alignment_history = final_state.alignment_history.stack()
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
                # Stacked history is time major: (time, batch, encoder_max_time).
                self.assertEqual(
                    (None, batch_size, encoder_max_time),
                    tuple(state_alignment_history.get_shape().as_list()))
            else:
                state_alignment_history = ()

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            # Debug output: prints the actual values, labelled "Copy/paste", so
            # they can be pasted back into the caller as new expected values.
            print("Copy/paste (%s)\nexpected_final_output = " % name,
                  sess_results["final_outputs"])
            sys.stdout.flush()
            print("Copy/paste (%s)\nexpected_final_state = " % name,
                  sess_results["final_state"])
            sys.stdout.flush()
            print(
                "Copy/paste (%s)\nexpected_final_alignment_history = " % name,
                sess_results["state_alignment_history"])
            sys.stdout.flush()
            nest.map_structure(self.assertAllClose, expected_final_output,
                               sess_results["final_outputs"])
            nest.map_structure(self.assertAllClose, expected_final_state,
                               sess_results["final_state"])
            if alignment_history:  # the history is only recorded in this case
                self.assertAllClose(
                    # the stacked TensorArray is time major, unlike the batch-major outputs
                    sess_results["state_alignment_history"],
                    expected_final_alignment_history)