def testLuongScaledDType(self):
    # Test case for GitHub issue 18099
    for dtype in [np.float16, np.float32, np.float64]:
      num_units = 128
      encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
      encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
      decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
      decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
      batch_size = 64
      attention_mechanism = wrapper.LuongAttention(
          num_units=num_units,
          memory=encoder_outputs,
          memory_sequence_length=encoder_sequence_length,
          scale=True,
          dtype=dtype,
      )
      cell = rnn_cell.LSTMCell(num_units)
      cell = wrapper.AttentionWrapper(cell, attention_mechanism)

      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtype, batch_size=batch_size))

      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
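These snippets are lifted from TF 1.x test and model files and omit their imports. A minimal preamble that covers the test-style examples on this page, assuming a TF 1.x install (exact module paths vary slightly across 1.x releases):

import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.util import nest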
Example #2
    def create_decoder(self, encoded, inputs, speaker_embed, train=True):
        config = self.config
        attention_mech = wrapper.BahdanauAttention(
            config.attention_units,
            encoded,
            memory_sequence_length=inputs['text_length'])

        inner_cell = [GRUCell(config.decoder_units) for _ in range(3)]

        decoder_cell = OutputProjectionWrapper(
            InputProjectionWrapper(ResidualWrapper(MultiRNNCell(inner_cell)),
                                   config.decoder_units),
            config.mel_features * config.r)

        # Feed in the r-th frame at each time step: run the most recent
        # output frame through the pre-net, then concatenate the attention
        # context onto it.
        def decoder_frame_input(inputs, attention):
            last_frame = tf.slice(
                inputs, [0, (config.r - 1) * config.mel_features], [-1, -1])
            processed = self.pre_net(last_frame,
                                     dropout=config.audio_dropout_prob,
                                     train=train)
            return tf.concat([processed, attention], -1)

        cell = wrapper.AttentionWrapper(
            decoder_cell,
            attention_mech,
            attention_layer_size=config.attention_units,
            cell_input_fn=decoder_frame_input,
            alignment_history=True,
            output_attention=False)

        if train:
            if config.scheduled_sample:
                print("if train if config.scheduled_sample: %s" % str(
                    (inputs['mel'], inputs['speech_length'],
                     config.scheduled_sample)))
                decoder_helper = helper.ScheduledOutputTrainingHelper(
                    inputs['mel'], inputs['speech_length'],
                    config.scheduled_sample)
            else:
                decoder_helper = helper.TrainingHelper(inputs['mel'],
                                                       inputs['speech_length'])
        else:
            decoder_helper = ops.InferenceHelper(
                tf.shape(inputs['text'])[0], config.mel_features * config.r)

        initial_state = cell.zero_state(dtype=tf.float32,
                                        batch_size=tf.shape(inputs['text'])[0])

        # if speaker_embed is not None:
        #     initial_state.attention = tf.layers.dense(
        #         speaker_embed, config.attention_units)

        dec = basic_decoder.BasicDecoder(cell, decoder_helper, initial_state)

        return dec
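A caller would then drive the returned decoder with dynamic_decode. A minimal sketch of that call site; model, encoded, inputs, speaker_embed, and max_iters are assumed names, not part of the snippet above:

# Hypothetical call site for create_decoder; every name here is an assumption.
dec = model.create_decoder(encoded, inputs, speaker_embed, train=False)
final_outputs, final_state, final_lengths = decoder.dynamic_decode(
    dec, maximum_iterations=max_iters)
mel_frames = final_outputs.rnn_output  # [batch, time, mel_features * r]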
Example #3
    def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
            self, use_sequence_length):
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = 10
        max_out = max(sequence_length)

        with self.test_session(use_gpu=True) as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)

            cell = core_rnn_cell.LSTMCell(cell_depth)
            zero_state = cell.zero_state(dtype=dtypes.float32,
                                         batch_size=batch_size)
            helper = helper_py.TrainingHelper(inputs, sequence_length)
            my_decoder = basic_decoder.BasicDecoder(cell=cell,
                                                    helper=helper,
                                                    initial_state=zero_state)

            # Match the variable scope of dynamic_rnn below so we end up
            # using the same variables
            with vs.variable_scope("root") as scope:
                final_decoder_outputs, final_decoder_state, _ = decoder.dynamic_decode(
                    my_decoder,
                    # impute_finished=True ensures outputs and final state
                    # match those of dynamic_rnn called with sequence_length not None
                    impute_finished=use_sequence_length,
                    scope=scope)

            with vs.variable_scope(scope, reuse=True) as scope:
                final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn(
                    cell,
                    inputs,
                    sequence_length=sequence_length
                    if use_sequence_length else None,
                    initial_state=zero_state,
                    scope=scope)

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_decoder_outputs": final_decoder_outputs,
                "final_decoder_state": final_decoder_state,
                "final_rnn_outputs": final_rnn_outputs,
                "final_rnn_state": final_rnn_state
            })

            # Decoder only runs out to max_out; ensure values are identical
            # to dynamic_rnn, which also zeros out outputs and passes along state.
            self.assertAllClose(
                sess_results["final_decoder_outputs"].rnn_output,
                sess_results["final_rnn_outputs"][:, 0:max_out, :])
            if use_sequence_length:
                self.assertAllClose(sess_results["final_decoder_state"],
                                    sess_results["final_rnn_state"])
Example #4
  def testStepWithScheduledEmbeddingTrainingHelper(self):
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    vocabulary_size = 10

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(
          batch_size, max_time, input_depth).astype(np.float32)
      embeddings = np.random.randn(
          vocabulary_size, input_depth).astype(np.float32)
      half = constant_op.constant(0.5)
      cell = rnn_cell.LSTMCell(vocabulary_size)
      helper = helper_py.ScheduledEmbeddingTrainingHelper(
          inputs=inputs,
          sequence_length=sequence_length,
          embedding=embeddings,
          sampling_probability=half,
          time_major=False)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(vocabulary_size,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = my_decoder.batch_size

      self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(
          isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
      self.assertEqual((batch_size, vocabulary_size),
                       step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, vocabulary_size),
                       first_state[0].get_shape())
      self.assertEqual((batch_size, vocabulary_size),
                       first_state[1].get_shape())
      self.assertEqual((batch_size, vocabulary_size),
                       step_state[0].get_shape())
      self.assertEqual((batch_size, vocabulary_size),
                       step_state[1].get_shape())
      self.assertEqual((batch_size, input_depth),
                       step_next_inputs.get_shape())

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      })

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])
      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      batch_where_not_sampling = np.where(sample_ids == -1)
      batch_where_sampling = np.where(sample_ids > -1)
      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_sampling],
          embeddings[sample_ids[batch_where_sampling]])
      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_not_sampling],
          np.squeeze(inputs[batch_where_not_sampling, 1]))
Example #5
  def testStepWithSampleEmbeddingHelper(self):
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size  # cell's logits must match vocabulary size
    input_depth = 10
    np.random.seed(0)
    start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
    end_token = 1

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithSampleEmbeddingHelper",
          initializer=init_ops.constant_initializer(0.01)):
        embeddings = np.random.randn(vocabulary_size,
                                     input_depth).astype(np.float32)
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.SampleEmbeddingHelper(embeddings, start_tokens,
                                                 end_token, seed=0)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth,
                                             tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = (sample_ids == end_token)
        expected_step_next_inputs = embeddings[sample_ids]
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])
Example #6
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_output,
                           expected_final_state,
                           attention_mechanism_depth=3,
                           alignment_history=False,
                           expected_final_alignment_history=None,
                           attention_layer_size=6,
                           name=''):
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9

        if attention_layer_size is not None:
            attention_depth = attention_layer_size
        else:
            attention_depth = encoder_output_depth

        decoder_inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time,
                            input_depth).astype(np.float32),
            shape=(None, None, input_depth))
        encoder_outputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, encoder_max_time,
                            encoder_output_depth).astype(np.float32),
            shape=(None, None, encoder_output_depth))

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session(use_gpu=True) as sess:
            with vs.variable_scope(
                    'root',
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_layer_size=attention_layer_size,
                    alignment_history=alignment_history)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                final_outputs, final_state, _ = decoder.dynamic_decode(
                    my_decoder)

            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.AttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            if alignment_history:
                state_alignment_history = final_state.alignment_history.stack()
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
                self.assertEqual(
                    (None, batch_size, None),
                    tuple(state_alignment_history.get_shape().as_list()))
            else:
                state_alignment_history = ()

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                'final_outputs': final_outputs,
                'final_state': final_state,
                'state_alignment_history': state_alignment_history,
            })

            final_output_info = nest.map_structure(
                get_result_summary, sess_results['final_outputs'])
            final_state_info = nest.map_structure(get_result_summary,
                                                  sess_results['final_state'])
            print(name)
            print('Copy/paste:\nexpected_final_output = %s' %
                  str(final_output_info))
            print('expected_final_state = %s' % str(final_state_info))
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_output, final_output_info)
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_state, final_state_info)
            if alignment_history:  # by default, the wrapper emits attention as output
                final_alignment_history_info = nest.map_structure(
                    get_result_summary,
                    sess_results['state_alignment_history'])
                print('expected_final_alignment_history = %s' %
                      str(final_alignment_history_info))
                nest.map_structure(
                    self.assertAllCloseOrEqual,
                    # outputs are batch major but the stacked TensorArray is time major
                    expected_final_alignment_history,
                    final_alignment_history_info)
Example #7
  def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):

    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    max_out = max(sequence_length)

    with self.session(use_gpu=True) as sess:
      if time_major:
        inputs = np.random.randn(max_time, batch_size,
                                 input_depth).astype(np.float32)
      else:
        inputs = np.random.randn(batch_size, max_time,
                                 input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=time_major)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      final_outputs, final_state, final_sequence_length = (
          decoder.dynamic_decode(my_decoder, output_time_major=time_major,
                                 maximum_iterations=maximum_iterations))

      def _t(shape):
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(isinstance(final_state, rnn_cell.LSTMStateTuple))

      self.assertEqual(
          (batch_size,),
          tuple(final_sequence_length.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None, cell_depth)),
          tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None)),
          tuple(final_outputs.sample_id.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "final_outputs": final_outputs,
          "final_state": final_state,
          "final_sequence_length": final_sequence_length,
      })

      # Mostly a smoke test
      time_steps = max_out
      expected_length = sequence_length
      if maximum_iterations is not None:
        time_steps = min(max_out, maximum_iterations)
        expected_length = [min(x, maximum_iterations) for x in expected_length]
      self.assertEqual(
          _t((batch_size, time_steps, cell_depth)),
          sess_results["final_outputs"].rnn_output.shape)
      self.assertEqual(
          _t((batch_size, time_steps)),
          sess_results["final_outputs"].sample_id.shape)
      self.assertItemsEqual(expected_length,
                            sess_results["final_sequence_length"])
Example #8
    def testStepWithGreedyEmbeddingHelper(self):
        batch_size = 5
        vocabulary_size = 7
        cell_depth = vocabulary_size  # cell's logits must match vocabulary size
        input_depth = 10
        start_tokens = [0] * batch_size
        end_token = 1

        with self.test_session() as sess:
            embeddings = np.random.randn(vocabulary_size,
                                         input_depth).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(vocabulary_size)
            helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
                                                     end_token)
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))
            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)
            batch_size_t = my_decoder.batch_size

            self.assertTrue(
                isinstance(first_state, core_rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_state, core_rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            })

            expected_sample_ids = np.argmax(
                sess_results["step_outputs"].rnn_output, -1)
            expected_step_finished = (expected_sample_ids == end_token)
            expected_step_next_inputs = embeddings[expected_sample_ids]
            self.assertAllEqual([False, False, False, False, False],
                                sess_results["first_finished"])
            self.assertAllEqual(expected_step_finished,
                                sess_results["step_finished"])
            self.assertAllEqual(expected_sample_ids,
                                sess_results["step_outputs"].sample_id)
            self.assertAllEqual(expected_step_next_inputs,
                                sess_results["step_next_inputs"])
Example #9
    def _build_decoder(self, encoder_outputs, encoder_state):
        with tf.name_scope("seq_decoder"):
            batch_size = self.batch_size
            # sequence_length = tf.fill([self.batch_size], self.num_steps)
            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                sequence_length = self.iterator.target_length
            else:
                sequence_length = self.iterator.source_length
            if (self.mode !=
                    tf.contrib.learn.ModeKeys.TRAIN) and self.beam_width > 1:
                batch_size = batch_size * self.beam_width
                encoder_outputs = beam_search_decoder.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                encoder_state = nest.map_structure(
                    lambda s: beam_search_decoder.tile_batch(
                        s, self.beam_width), encoder_state)
                sequence_length = beam_search_decoder.tile_batch(
                    sequence_length, multiplier=self.beam_width)

            # Build a fresh cell per layer; reusing one cell instance across
            # layers can make MultiRNNCell share variables or raise, depending
            # on the TF version.
            decoder_cell = MultiRNNCell([
                single_rnn_cell(self.hparams.unit_type, self.num_units,
                                self.dropout)
                for _ in range(self.num_layers_decoder)
            ])
            decoder_cell = InputProjectionWrapper(decoder_cell,
                                                  num_proj=self.num_units)
            attention_mechanism = create_attention_mechanism(
                self.hparams.attention_mechanism,
                self.num_units,
                memory=encoder_outputs,
                source_sequence_length=sequence_length)
            decoder_cell = wrapper.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.num_units,
                output_attention=True,
                alignment_history=False)

            # Set the cell_state of the AttentionWrapperState to the encoder
            # state (the code below currently starts from zero_state instead).
            initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            embeddings_decoder = tf.get_variable(
                "embedding_decoder",
                [self.num_decoder_symbols, self.num_units],
                initializer=self.initializer,
                dtype=tf.float32)
            output_layer = Dense(units=self.num_decoder_symbols,
                                 use_bias=True,
                                 name="output_layer")

            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                decoder_inputs = tf.nn.embedding_lookup(
                    embeddings_decoder, self.iterator.target_in)
                decoder_helper = helper.TrainingHelper(
                    decoder_inputs, sequence_length=sequence_length)

                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
                final_outputs, final_state, _ = decoder.dynamic_decode(dec)
                output_ids = final_outputs.sample_id  # token ids
                outputs = final_outputs.rnn_output    # logits
            else:

                def embedding_fn(inputs):
                    return tf.nn.embedding_lookup(embeddings_decoder, inputs)

                decoding_length_factor = 2.0
                max_encoder_length = tf.reduce_max(self.iterator.source_length)
                maximum_iterations = tf.to_int32(
                    tf.round(
                        tf.to_float(max_encoder_length) *
                        decoding_length_factor))

                tgt_sos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                    tf.int32)
                tgt_eos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                    tf.int32)
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                if self.beam_width == 1:
                    decoder_helper = helper.GreedyEmbeddingHelper(
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    dec = basic_decoder.BasicDecoder(decoder_cell,
                                                     decoder_helper,
                                                     initial_state,
                                                     output_layer=output_layer)
                else:
                    dec = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        output_layer=output_layer,
                        beam_width=self.beam_width)
                final_outputs, final_state, _ = decoder.dynamic_decode(
                    dec,
                    # swap_memory=True,
                    maximum_iterations=maximum_iterations)
                # TRAIN mode is handled in the branch above, so only the
                # beam width matters here.
                if self.beam_width == 1:
                    output_ids = final_outputs.sample_id
                    outputs = final_outputs.rnn_output
                else:
                    output_ids = final_outputs.predicted_ids
                    outputs = final_outputs.beam_search_decoder_output.scores

            return output_ids, outputs
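The beam-search branch above leans on tile_batch, which repeats every batch entry beam_width times along axis 0 so each beam sees its own copy of the encoder memory. A minimal illustration with assumed values:

import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder

# Each entry is repeated `multiplier` times consecutively:
# [[1.], [2.]] with multiplier=3 -> [[1.], [1.], [1.], [2.], [2.], [2.]]
x = tf.constant([[1.0], [2.0]])                          # shape [2, 1]
tiled = beam_search_decoder.tile_batch(x, multiplier=3)  # shape [6, 1]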
Example #10
  def testStepWithInferenceHelperCategorical(self):
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        # one_hot requires integer indices, hence the explicit dtype.
        np.ones(batch_size, dtype=np.int32) * start_token,
        vocabulary_size)

    # The sample function samples categorically from the logits.
    sample_fn = lambda x: helper_py.categorical_sample(logits=x)
    # The next inputs are a one-hot encoding of the sampled labels.
    next_inputs_fn = (
        lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32))
    end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token)

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithInferenceHelper",
          initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn, sample_shape=(), sample_dtype=dtypes.int32,
            start_inputs=start_inputs, end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth,
                                             tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = (sample_ids == end_token)
        expected_step_next_inputs = np.zeros((batch_size, vocabulary_size))
        expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])
Example #11
  def testStepWithTrainingHelper(self):
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10

    with self.test_session() as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = core_rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=False)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = my_decoder.batch_size

      self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
      self.assertTrue(
          isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      })

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])
      self.assertAllEqual(
          np.argmax(sess_results["step_outputs"].rnn_output, -1),
          sess_results["step_outputs"].sample_id)
Example #12
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_outputs,
                           expected_final_state,
                           attention_mechanism_depth=3):
        encoder_sequence_length = [3, 2, 3, 1, 0]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        attention_depth = 6

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session() as sess:
            with vs.variable_scope(
                    "root",
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = core_rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.DynamicAttentionWrapper(
                    cell, attention_mechanism, attention_size=attention_depth)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                final_outputs, final_state = decoder.dynamic_decode(my_decoder)

            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.DynamicAttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state,
                           core_rnn_cell.LSTMStateTuple))

            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs": final_outputs,
                "final_state": final_state
            })

            nest.map_structure(self.assertAllClose, expected_final_outputs,
                               sess_results["final_outputs"])
            nest.map_structure(self.assertAllClose, expected_final_state,
                               sess_results["final_state"])
Example #13
  def _testWithMaybeMultiAttention(self,
                                   is_multi,
                                   create_attention_mechanisms,
                                   expected_final_output,
                                   expected_final_state,
                                   attention_mechanism_depths,
                                   alignment_history=False,
                                   expected_final_alignment_history=None,
                                   attention_layer_sizes=None,
                                   name=''):
    # Allow is_multi to be True with a single mechanism to enable test for
    # passing in a single mechanism in a list.
    assert len(create_attention_mechanisms) == 1 or is_multi
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    if attention_layer_sizes is None:
      attention_depth = encoder_output_depth * len(create_attention_mechanisms)
    else:
      # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
      attention_depth = sum([attention_layer_size or encoder_output_depth
                             for attention_layer_size in attention_layer_sizes])

    decoder_inputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, decoder_max_time,
                        input_depth).astype(np.float32),
        shape=(None, None, input_depth))
    encoder_outputs = array_ops.placeholder_with_default(
        np.random.randn(batch_size, encoder_max_time,
                        encoder_output_depth).astype(np.float32),
        shape=(None, None, encoder_output_depth))

    attention_mechanisms = [
        creator(num_units=depth,
                memory=encoder_outputs,
                memory_sequence_length=encoder_sequence_length)
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths)]

    with self.test_session(use_gpu=True) as sess:
      with vs.variable_scope(
          'root',
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        cell = rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanisms if is_multi else attention_mechanisms[0],
            attention_layer_size=(attention_layer_sizes if is_multi
                                  else attention_layer_sizes[0]),
            alignment_history=alignment_history)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        if is_multi:
          state_alignment_history = []
          for history_array in final_state.alignment_history:
            history = history_array.stack()
            self.assertEqual(
                (None, batch_size, None),
                tuple(history.get_shape().as_list()))
            state_alignment_history.append(history)
          state_alignment_history = tuple(state_alignment_history)
        else:
          state_alignment_history = final_state.alignment_history.stack()
          self.assertEqual(
              (None, batch_size, None),
              tuple(state_alignment_history.get_shape().as_list()))
        nest.assert_same_structure(
            cell.state_size,
            cell.zero_state(batch_size, dtypes.float32))
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'state_alignment_history': state_alignment_history,
      })

      final_output_info = nest.map_structure(get_result_summary,
                                             sess_results['final_outputs'])
      final_state_info = nest.map_structure(get_result_summary,
                                            sess_results['final_state'])
      print(name)
      print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
      print('expected_final_state = %s' % str(final_state_info))
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
                         final_output_info)
      nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
                         final_state_info)
      if alignment_history:  # by default, the wrapper emits attention as output
        final_alignment_history_info = nest.map_structure(
            get_result_summary, sess_results['state_alignment_history'])
        print('expected_final_alignment_history = %s' %
              str(final_alignment_history_info))
        nest.map_structure(
            self.assertAllCloseOrEqual,
            # outputs are batch major but the stacked TensorArray is time major
            expected_final_alignment_history,
            final_alignment_history_info)
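For the is_multi path, AttentionWrapper takes a list of mechanisms plus a matching list of attention_layer_size values, and the resulting attention depth is their sum. A minimal two-mechanism sketch; the sizes are assumptions chosen to be compatible (Luong requires num_units to equal the query depth, i.e. the cell size):

# Sketch of a two-mechanism AttentionWrapper; all sizes are assumed values.
mechanisms = [
    wrapper.BahdanauAttention(num_units=3, memory=encoder_outputs,
                              memory_sequence_length=encoder_sequence_length),
    wrapper.LuongAttention(num_units=9, memory=encoder_outputs,
                           memory_sequence_length=encoder_sequence_length),
]
multi_cell = wrapper.AttentionWrapper(
    rnn_cell.LSTMCell(9), mechanisms,
    attention_layer_size=[6, 6],   # attention depth = 6 + 6 = 12
    alignment_history=True)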
Example #14
    def _testStepWithScheduledOutputTrainingHelper(self, use_next_input_layer):
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = input_depth
        if use_next_input_layer:
            cell_depth = 6

        with self.test_session() as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            cell = core_rnn_cell.LSTMCell(cell_depth)
            half = constant_op.constant(0.5)

            next_input_layer = None
            if use_next_input_layer:
                next_input_layer = layers_core.Dense(input_depth,
                                                     use_bias=False)

            helper = helper_py.ScheduledOutputTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                sampling_probability=half,
                time_major=False,
                next_input_layer=next_input_layer)

            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=helper,
                initial_state=cell.zero_state(dtype=dtypes.float32,
                                              batch_size=batch_size))

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(cell_depth,
                                                 tensor_shape.TensorShape([])),
                output_size)
            self.assertEqual(
                basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
                output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(constant_op.constant(0),
                                              first_inputs, first_state)

            if use_next_input_layer:
                output_after_next_input_layer = next_input_layer(
                    step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            self.assertTrue(
                isinstance(first_state, core_rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_state, core_rnn_cell.LSTMStateTuple))
            self.assertTrue(
                isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            sess.run(variables.global_variables_initializer())

            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            if use_next_input_layer:
                fetches["output_after_next_input_layer"] = (
                    output_after_next_input_layer)

            sess_results = sess.run(fetches)

            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])

            sample_ids = sess_results["step_outputs"].sample_id
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))
            batch_where_sampling = np.where(sample_ids)
            if use_next_input_layer:
                self.assertAllClose(
                    sess_results["step_next_inputs"][batch_where_sampling],
                    sess_results["output_after_next_input_layer"]
                    [batch_where_sampling])
            else:
                self.assertAllClose(
                    sess_results["step_next_inputs"][batch_where_sampling],
                    sess_results["step_outputs"].rnn_output[
                        batch_where_sampling])
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_not_sampling],
                np.squeeze(inputs[batch_where_not_sampling, 1], axis=1))
Example #15
  def _testStepWithScheduledOutputTrainingHelper(
      self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = input_depth
    if use_auxiliary_inputs:
      auxiliary_input_depth = 4
      auxiliary_inputs = np.random.randn(
          batch_size, max_time, auxiliary_input_depth).astype(np.float32)
    else:
      auxiliary_inputs = None

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      sampling_probability = constant_op.constant(sampling_probability)

      if use_next_inputs_fn:
        def next_inputs_fn(outputs):
          # Use deterministic function for test.
          samples = math_ops.argmax(outputs, axis=1)
          return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32)
      else:
        next_inputs_fn = None

      helper = helper_py.ScheduledOutputTrainingHelper(
          inputs=inputs,
          sequence_length=sequence_length,
          sampling_probability=sampling_probability,
          time_major=False,
          next_inputs_fn=next_inputs_fn,
          auxiliary_inputs=auxiliary_inputs)

      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)

      if use_next_inputs_fn:
        output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)

      batch_size_t = my_decoder.batch_size

      self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(
          isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())

      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      if use_next_inputs_fn:
        fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn

      sess_results = sess.run(fetches)

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])

      sample_ids = sess_results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      batch_where_not_sampling = np.where(np.logical_not(sample_ids))
      batch_where_sampling = np.where(sample_ids)

      auxiliary_inputs_to_concat = (
          auxiliary_inputs[:, 1] if use_auxiliary_inputs else
          np.array([]).reshape(batch_size, 0).astype(np.float32))
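      # Where sampling occurred, the next inputs should be the cell output
      # (after next_inputs_fn, if given) concatenated with any auxiliary
      # inputs; elsewhere they fall back to the ground-truth inputs.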

      expected_next_sampling_inputs = np.concatenate(
          (sess_results["output_after_next_inputs_fn"][batch_where_sampling]
           if use_next_inputs_fn else
           sess_results["step_outputs"].rnn_output[batch_where_sampling],
           auxiliary_inputs_to_concat[batch_where_sampling]),
          axis=-1)
      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_sampling],
          expected_next_sampling_inputs)

      self.assertAllClose(
          sess_results["step_next_inputs"][batch_where_not_sampling],
          np.concatenate(
              (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0),
               auxiliary_inputs_to_concat[batch_where_not_sampling]),
              axis=-1))
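
  # (Sketch, not in the original file: a parameterized method like the one
  # above is normally exercised by thin named wrappers; the method names
  # below are assumptions.)
  def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self):
    self._testStepWithScheduledOutputTrainingHelper(
        sampling_probability=0.5, use_next_inputs_fn=True,
        use_auxiliary_inputs=False)

  def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self):
    self._testStepWithScheduledOutputTrainingHelper(
        sampling_probability=0.5, use_next_inputs_fn=False,
        use_auxiliary_inputs=True)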
Example #16
  def _testStepWithTrainingHelper(self, use_output_layer):
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10
    output_layer_depth = 3

    with self.session(use_gpu=True) as sess:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.TrainingHelper(
          inputs, sequence_length, time_major=False)
      if use_output_layer:
        output_layer = layers_core.Dense(output_layer_depth, use_bias=False)
        expected_output_depth = output_layer_depth
      else:
        output_layer = None
        expected_output_depth = cell_depth
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size),
          output_layer=output_layer)
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(expected_output_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = my_decoder.batch_size

      self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
      self.assertTrue(
          isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
      self.assertEqual((batch_size, expected_output_depth),
                       step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      if use_output_layer:
        # The output layer was accessed
        self.assertEqual(len(output_layer.variables), 1)

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      })

      self.assertAllEqual([False, False, False, False, True],
                          sess_results["first_finished"])
      self.assertAllEqual([False, False, False, True, True],
                          sess_results["step_finished"])
      self.assertEqual(output_dtype.sample_id,
                       sess_results["step_outputs"].sample_id.dtype)
      self.assertAllEqual(
          np.argmax(sess_results["step_outputs"].rnn_output, -1),
          sess_results["step_outputs"].sample_id)
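
  # (Sketch, not in the original file: hypothetical wrappers driving the
  # parameterized method above with and without an output layer.)
  def testStepWithTrainingHelperNoOutputLayer(self):
    self._testStepWithTrainingHelper(use_output_layer=False)

  def testStepWithTrainingHelperWithOutputLayer(self):
    self._testStepWithTrainingHelper(use_output_layer=True)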
    def _testWithAttention(self,
                           create_attention_mechanism,
                           expected_final_output,
                           expected_final_state,
                           attention_mechanism_depth=3,
                           alignment_history=False,
                           expected_final_alignment_history=None,
                           name=""):
        encoder_sequence_length = [3, 2, 3, 1, 0]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        attention_depth = 6

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanism = create_attention_mechanism(
            num_units=attention_mechanism_depth,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)

        with self.test_session(use_gpu=True) as sess:
            with vs.variable_scope(
                    "root",
                    initializer=init_ops.random_normal_initializer(stddev=0.01,
                                                                   seed=3)):
                cell = core_rnn_cell.LSTMCell(cell_depth)
                cell = wrapper.AttentionWrapper(
                    cell,
                    attention_mechanism,
                    attention_size=attention_depth,
                    alignment_history=alignment_history)
                helper = helper_py.TrainingHelper(decoder_inputs,
                                                  decoder_sequence_length)
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=cell.zero_state(dtype=dtypes.float32,
                                                  batch_size=batch_size))

                final_outputs, final_state = decoder.dynamic_decode(my_decoder)

            self.assertTrue(
                isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
            self.assertTrue(
                isinstance(final_state, wrapper.AttentionWrapperState))
            self.assertTrue(
                isinstance(final_state.cell_state,
                           core_rnn_cell.LSTMStateTuple))

            # By default the wrapper emits the attention vector as its
            # output, so rnn_output has depth attention_depth.
            self.assertEqual(
                (batch_size, None, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, None),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.c.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state.h.get_shape().as_list()))

            if alignment_history:
                state_alignment_history = final_state.alignment_history.stack()
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
                self.assertEqual(
                    (None, batch_size, encoder_max_time),
                    tuple(state_alignment_history.get_shape().as_list()))
            else:
                state_alignment_history = ()

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            print("Copy/paste (%s)\nexpected_final_output = " % name,
                  sess_results["final_outputs"])
            sys.stdout.flush()
            print("Copy/paste (%s)\nexpected_final_state = " % name,
                  sess_results["final_state"])
            sys.stdout.flush()
            print(
                "Copy/paste (%s)\nexpected_final_alignment_history = " % name,
                sess_results["state_alignment_history"])
            sys.stdout.flush()
            nest.map_structure(self.assertAllClose, expected_final_output,
                               sess_results["final_outputs"])
            nest.map_structure(self.assertAllClose, expected_final_state,
                               sess_results["final_state"])
            if alignment_history:
                self.assertAllClose(
                    # Decoder outputs are batch major, but the stacked
                    # alignment-history TensorArray is time major:
                    # (time, batch_size, encoder_max_time).
                    sess_results["state_alignment_history"],
                    expected_final_alignment_history)
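
    # (Sketch, not in the original file: the harness above is typically
    # driven by binding the mechanism constructor, e.g. with
    # functools.partial; expected_final_output / expected_final_state would
    # be pasted in from the "Copy/paste" blocks printed by an earlier run.
    # All names here are assumptions.)
    def testBahdanauNotNormalized(self):
        self._testWithAttention(
            functools.partial(wrapper.BahdanauAttention),
            expected_final_output,
            expected_final_state,
            name="testBahdanauNotNormalized")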
Example #18
  def testStepWithInferenceHelperMultilabel(self):
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size) * start_token,
        vocabulary_size)

    # The sample function samples independent bernoullis from the logits.
    sample_fn = (
        lambda x: helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool))
    # The next inputs are a one-hot encoding of the sampled labels.
    next_inputs_fn = math_ops.to_float
    end_fn = lambda sample_ids: sample_ids[:, end_token]
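
    # In this multilabel setting each step emits a length-`cell_depth`
    # boolean vector rather than a single token id; a batch entry is marked
    # finished once its Bernoulli sample at position `end_token` is True.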

    with self.session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          "testStepWithInferenceHelper",
          initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool,
            start_inputs=start_inputs, end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        expected_step_finished = sample_ids[:, end_token]
        expected_step_next_inputs = sample_ids.astype(np.float32)
        self.assertAllEqual(expected_step_finished,
                            sess_results["step_finished"])
        self.assertAllEqual(expected_step_next_inputs,
                            sess_results["step_next_inputs"])
Example #19
# The opening of this snippet was truncated; the preamble below is a
# reconstruction, and the module aliases, shapes, and choice of
# BahdanauAttention are assumptions consistent with the calls that follow.
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py

batch_size = 5
num_units = 16
src_len = [4, 3, 6, 5, 2]
enc_output = tf.random.normal((batch_size, 6, num_units), dtype=tf.float32)
rnncell = tf.nn.rnn_cell.GRUCell(num_units)

attention_mechanism = attention_wrapper.BahdanauAttention(
    num_units=num_units, memory=enc_output, memory_sequence_length=src_len)
attnRNNCell = attention_wrapper.AttentionWrapper(
    cell=rnncell,
    attention_mechanism=attention_mechanism,
    alignment_history=True)
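
# With alignment_history=True, the wrapper accumulates each step's attention
# alignments in a TensorArray; they can be recovered after decoding via
# final_state.alignment_history.stack().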

# training
tgt_len = [5, 6, 2, 7, 4]
tgt_max_times = 7
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units),
                              dtype=tf.float32)
training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len)

# train helper
train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))
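
# (Sketch, not in the original snippet: the training decoder would be
# unrolled the same way as the inference decoder below, e.g.
#   train_outputs, train_state, train_lengths = decoder.dynamic_decode(
#       train_decoder)
# with the training loss computed on train_outputs.rnn_output.)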

# inference
embedding = tf.get_variable("embedding",
                            shape=(10, 16),
                            initializer=tf.random_uniform_initializer())
infer_helper = helper_py.GreedyEmbeddingHelper(
    embedding=embedding,  # may be a callable or an embedding matrix
    start_tokens=tf.zeros([batch_size], dtype=tf.int32),
    end_token=9)
infer_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=infer_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(