def _testDynamicDecodeRNN(self, time_major, has_attention):
        encoder_sequence_length = np.array([3, 2, 3, 1, 1])
        decoder_sequence_length = np.array([2, 0, 1, 2, 3])
        batch_size = 5
        decoder_max_time = 4
        input_depth = 7
        cell_depth = 9
        attention_depth = 6
        vocab_size = 20
        end_token = vocab_size - 1
        start_token = 0
        embedding_dim = 50
        max_out = max(decoder_sequence_length)
        output_layer = layers_core.Dense(vocab_size,
                                         use_bias=True,
                                         activation=None)
        beam_width = 3

        with self.test_session() as sess:
            batch_size_tensor = constant_op.constant(batch_size)
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            cell = rnn_cell.LSTMCell(cell_depth)
            initial_state = cell.zero_state(batch_size, dtypes.float32)
            if has_attention:
                inputs = array_ops.placeholder_with_default(
                    np.random.randn(batch_size, decoder_max_time,
                                    input_depth).astype(np.float32),
                    shape=(None, None, input_depth))
                # Tile the attention memory and its lengths so their batch
                # dimension matches the decoder's batch_size * beam_width.
                tiled_inputs = beam_search_decoder.tile_batch(
                    inputs, multiplier=beam_width)
                tiled_sequence_length = beam_search_decoder.tile_batch(
                    encoder_sequence_length, multiplier=beam_width)
                attention_mechanism = attention_wrapper.BahdanauAttention(
                    num_units=attention_depth,
                    memory=tiled_inputs,
                    memory_sequence_length=tiled_sequence_length)
                initial_state = beam_search_decoder.tile_batch(
                    initial_state, multiplier=beam_width)
                cell = attention_wrapper.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=attention_depth,
                    alignment_history=False)
            # The decoder runs over batch_size * beam_width sequences; with
            # attention, seed its cell_state with the tiled encoder state.
            cell_state = cell.zero_state(dtype=dtypes.float32,
                                         batch_size=batch_size_tensor *
                                         beam_width)
            if has_attention:
                cell_state = cell_state.clone(cell_state=initial_state)
            bsd = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                embedding=embedding,
                start_tokens=array_ops.fill([batch_size_tensor], start_token),
                end_token=end_token,
                initial_state=cell_state,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0)

            final_outputs, final_state, final_sequence_lengths = (
                decoder.dynamic_decode(bsd,
                                       output_time_major=time_major,
                                       maximum_iterations=max_out))

            def _t(shape):
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertIsInstance(
                final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
            self.assertIsInstance(
                final_state, beam_search_decoder.BeamSearchDecoderState)

            beam_search_decoder_output = final_outputs.beam_search_decoder_output
            self.assertEqual(
                _t((batch_size, None, beam_width)),
                tuple(beam_search_decoder_output.scores.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, None, beam_width)),
                tuple(final_outputs.predicted_ids.get_shape().as_list()))

            sess.run(variables.global_variables_initializer())
            sess_results = sess.run({
                'final_outputs': final_outputs,
                'final_state': final_state,
                'final_sequence_lengths': final_sequence_lengths
            })

            max_sequence_length = np.max(
                sess_results['final_sequence_lengths'])

            # A smoke test
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                sess_results['final_outputs'].beam_search_decoder_output.
                scores.shape)
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                sess_results['final_outputs'].beam_search_decoder_output.
                predicted_ids.shape)
Example #2
    def get_decoder(cell,
                    y__ref_flag,
                    x_ref_flag,
                    tgt_ref_flag,
                    beam_width=None):
        output_layer_params = (
            {"output_layer": tf.identity} if Config.copy_flag
            else {"vocab_size": vocab.size})

        if Config.attn_flag:  # attention
            if Config.attn_x and Config.attn_y_:
                memory = tf.concat(
                    [
                        sent_enc_outputs[y__ref_flag],
                        sd_enc_outputs[x_ref_flag]
                    ],
                    axis=1,
                )
                memory_sequence_length = None
            elif Config.attn_y_:
                memory = sent_enc_outputs[y__ref_flag]
                memory_sequence_length = sent_sequence_length[y__ref_flag]
            elif Config.attn_x:
                memory = sd_enc_outputs[x_ref_flag]
                memory_sequence_length = sd_sequence_length[x_ref_flag]
            else:
                raise Exception(
                    "Must enable at least one of Config.attn_y_ and "
                    "Config.attn_x.")
            attention_decoder = tx.modules.AttentionRNNDecoder(
                cell=cell,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                hparams=Config.config_model.attention_decoder,
                **output_layer_params)
            if not Config.copy_flag:
                return attention_decoder
            # Use the attention decoder's cell (its beam-search variant when
            # beam_width is given) as the base cell for CopyNetWrapper below.
            cell = (attention_decoder.cell if beam_width is None else
                    attention_decoder._get_beam_search_cell(beam_width))

        if Config.copy_flag:  # copynet
            kwargs = {
                "y__ids": sent_ids[y__ref_flag][:, 1:],
                "y__states": sent_enc_outputs[y__ref_flag][:, 1:],
                "y__lengths": sent_sequence_length[y__ref_flag] - 1,
                "x_ids": sd_ids[x_ref_flag]["value"],
                "x_states": sd_enc_outputs[x_ref_flag],
                "x_lengths": sd_sequence_length[x_ref_flag],
            }

            if tgt_ref_flag is not None:
                kwargs["input_ids"] = data_batch[
                    "{}_text_ids".format(y_strs[tgt_ref_flag])][:, :-1]

            memory_prefixes = []

            if Config.copy_y_:
                memory_prefixes.append("y_")

            if Config.copy_x:
                memory_prefixes.append("x")

            # For beam search, tile every memory tensor so its batch dimension
            # becomes batch_size * beam_width.
            if beam_width is not None:
                kwargs = {
                    name: tile_batch(value, beam_width)
                    for name, value in kwargs.items()
                }

            # Factory: project each memory once up front, then return a closure
            # that scores copy candidates against the decoder query per step.
            def get_get_copy_scores(memory_ids_states_lengths, output_size):
                memory_copy_states = [
                    tf.layers.dense(
                        memory_states,
                        units=output_size,
                        activation=None,
                        use_bias=False,
                    ) for _, memory_states, _ in memory_ids_states_lengths
                ]

                def get_copy_scores(query, coverities=None):
                    ret = []

                    if Config.copy_y_:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False,
                            )
                        memory = tf.nn.tanh(memory)
                        ret_y_ = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_y_)

                    if Config.copy_x:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False,
                            )
                        memory = tf.nn.tanh(memory)
                        ret_x = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_x)

                    return ret

                return get_copy_scores

            coverity_dim = (Config.config_model.coverage_state_dim
                            if Config.coverage else None)
            coverity_rnn_cell_hparams = (Config.config_model.coverage_rnn_cell
                                         if Config.coverage else None)
            cell = CopyNetWrapper(
                cell=cell,
                vocab_size=vocab.size,
                memory_ids_states_lengths=[
                    tuple(kwargs["{}_{}".format(prefix, s)]
                          for s in ("ids", "states", "lengths"))
                    for prefix in memory_prefixes
                ],
                input_ids=(kwargs["input_ids"]
                           if tgt_ref_flag is not None else None),
                get_get_copy_scores=get_get_copy_scores,
                coverity_dim=coverity_dim,
                coverity_rnn_cell_hparams=coverity_rnn_cell_hparams,
                disabled_vocab_size=Config.disabled_vocab_size,
                eps=Config.eps,
            )

        decoder = tx.modules.BasicRNNDecoder(
            cell=cell,
            hparams=Config.config_model.decoder,
            **output_layer_params)
        return decoder
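
The copy scores above reduce to a batched dot product between the decoder query and every projected memory position. A quick numpy check of that einsum, with made-up shapes:

import numpy as np

batch, positions, dim = 2, 4, 3
memory = np.random.randn(batch, positions, dim)  # b, i, m
query = np.random.randn(batch, dim)              # b, m

scores = np.einsum("bim,bm->bi", memory, query)  # b, i
# Same as a per-position dot product against the query:
assert np.allclose(scores, (memory * query[:, None, :]).sum(-1))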
Example #3
  def _testDynamicDecodeRNN(self, time_major, has_attention,
                            with_alignment_history=False):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    with self.cached_session() as sess:
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      coverage_penalty_weight = 0.0
      if has_attention:
        # A coverage penalty only makes sense when attention alignments exist.
        coverage_penalty_weight = 0.2
        inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        tiled_inputs = beam_search_decoder.tile_batch(
            inputs, multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=with_alignment_history)
      cell_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
      bsd = beam_search_decoder.BeamSearchDecoder(
          cell=cell,
          embedding=embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=cell_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0,
          coverage_penalty_weight=coverage_penalty_weight)

      final_outputs, final_state, final_sequence_lengths = (
          decoder.dynamic_decode(
              bsd, output_time_major=time_major, maximum_iterations=max_out))

      def _t(shape):
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertIsInstance(
          final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(
          final_state, beam_search_decoder.BeamSearchDecoderState)

      beam_search_decoder_output = final_outputs.beam_search_decoder_output
      self.assertEqual(
          _t((batch_size, None, beam_width)),
          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None, beam_width)),
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'final_sequence_lengths': final_sequence_lengths
      })

      max_sequence_length = np.max(sess_results['final_sequence_lengths'])

      # A smoke test
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)),
          sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)), sess_results[
              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
Example #4
def impl(features, mode, hp):
    # contexts: [batch_size, max_con_length (incl. query), max_sen_length]
    contexts = features['contexts']
    # context_utterance_length: [batch_size, max_con_length]
    context_utterance_length = features['context_utterance_length']
    context_length = features['context_length']  # [batch_size]
    if mode == modekeys.TRAIN or mode == modekeys.EVAL:
        response_in = features['response_in']  # [batch, max_res_sen]
        response_out = features['response_out']  # [batch, max_res_sen]
        # response_mask: [batch, max_res_sen], tf.float32
        response_mask = features['response_mask']
        batch_size = hp.batch_size
    else:
        batch_size = context_utterance_length.shape[0].value

    with tf.variable_scope('embedding_layer', reuse=tf.AUTO_REUSE) as vs:
        embedding_w = get_embedding_matrix(hp.word_dim, mode, hp.vocab_size,
                                           random_seed, hp.word_embed_path,
                                           hp.vocab_path)
        # Note: the third positional argument of tf.nn.embedding_lookup is
        # partition_strategy, so the op name must be passed via name=.
        contexts = tf.nn.embedding_lookup(embedding_w, contexts,
                                          name='context_embedding')
        if mode == modekeys.TRAIN or mode == modekeys.EVAL:
            response_in = tf.nn.embedding_lookup(embedding_w, response_in,
                                                 name='response_in_embedding')

    with tf.variable_scope('utterance_encoding_layer',
                           reuse=tf.AUTO_REUSE) as vs:
        kernel_initializer = tf.random_normal_initializer(mean=0.0,
                                                          stddev=0.1,
                                                          seed=random_seed + 1)
        bias_initializer = tf.zeros_initializer()
        fw_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.word_rnn_num_units,
                                         kernel_initializer=kernel_initializer,
                                         bias_initializer=bias_initializer)
        kernel_initializer = tf.random_normal_initializer(mean=0.0,
                                                          stddev=0.1,
                                                          seed=random_seed - 1)
        bias_initializer = tf.zeros_initializer()
        bw_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.word_rnn_num_units,
                                         kernel_initializer=kernel_initializer,
                                         bias_initializer=bias_initializer)

        # [max_con_length (incl. query), batch_size, max_sen_length, word_dim]
        context_t = tf.transpose(contexts, perm=[1, 0, 2, 3])
        # [max_con_length, batch_size]
        context_utterance_length_t = tf.transpose(context_utterance_length,
                                                  perm=[1, 0])
        # Each split: [1, batch_size, max_sen_length, word_dim]
        utterance_list = tf.split(context_t, hp.max_context_length, axis=0)
        # Each split: [1, batch_size]
        length_list = tf.split(context_utterance_length_t,
                               hp.max_context_length, axis=0)

        utterance_encodings = []
        for utterance, length in zip(utterance_list, length_list):
            utterance = tf.squeeze(utterance, axis=0)
            length = tf.squeeze(length, axis=0)
            utterance_hidden_states, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell,
                bw_cell,
                utterance,
                sequence_length=length,
                initial_state_fw=fw_cell.zero_state(batch_size, tf.float32),
                initial_state_bw=bw_cell.zero_state(batch_size, tf.float32))
            utterance_encoding = tf.concat(utterance_hidden_states, axis=2)
            utterance_encodings.append(
                tf.expand_dims(utterance_encoding, axis=0))

        utterance_encodings = tf.concat(
            utterance_encodings,
            axis=0)  # max_con_length,batch_size,max_sen,2*word_rnn_num_units

    with tf.variable_scope('hierarchical_attention_layer',
                           reuse=tf.AUTO_REUSE) as vs:
        if mode == modekeys.PREDICT and hp.beam_width != 0:
            # tile_batch tiles along the leading dimension, so move batch to
            # the front, tile, then restore the original layout.
            utterance_encodings = tf.transpose(utterance_encodings,
                                               perm=[1, 0, 2, 3])
            utterance_encodings = tile_batch(utterance_encodings,
                                             multiplier=hp.beam_width)
            utterance_encodings = tf.transpose(utterance_encodings,
                                               perm=[1, 0, 2, 3])

            context_utterance_length_t = tf.transpose(
                context_utterance_length_t, perm=[1, 0])
            context_utterance_length_t = tile_batch(context_utterance_length_t,
                                                    multiplier=hp.beam_width)
            context_utterance_length_t = tf.transpose(
                context_utterance_length_t, perm=[1, 0])

            context_length = tile_batch(context_length,
                                        multiplier=hp.beam_width)

        attention_mechanism = ContextAttentionMechanism(
            context_attn_units=hp.context_attn_units,
            utte_attn_units=hp.utte_attn_units,
            context=utterance_encodings,
            context_utterance_length=context_utterance_length_t,
            max_context_length=hp.max_context_length,
            context_rnn_num_units=hp.context_rnn_num_units,
            context_actual_length=context_length)

    with tf.variable_scope('decoder_layer', reuse=tf.AUTO_REUSE) as vs:
        kernel_initializer = tf.random_normal_initializer(mean=0.0,
                                                          stddev=0.1,
                                                          seed=random_seed + 3)
        bias_initializer = tf.zeros_initializer()
        decoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=hp.decoder_rnn_num_units,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer)
        attn_cell = AttentionWrapper(
            decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=None,
            output_attention=False)  # output_attention should be False
        output_layer = layers_core.Dense(
            units=hp.vocab_size, activation=None,
            use_bias=False)  # should use no activation and no bias

        if mode == modekeys.TRAIN:
            sequence_length = tf.constant(value=hp.max_sentence_length,
                                          dtype=tf.int32,
                                          shape=[batch_size])
            helper = TrainingHelper(inputs=response_in,
                                    sequence_length=sequence_length)
            decoder = BasicDecoder(cell=attn_cell,
                                   helper=helper,
                                   initial_state=attn_cell.zero_state(
                                       batch_size, tf.float32),
                                   output_layer=output_layer)
            final_outputs, final_state, final_sequence_lengths = dynamic_decode(
                decoder=decoder,
                impute_finished=True,
                parallel_iterations=32,
                swap_memory=True)
            logits = final_outputs.rnn_output
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=response_out, logits=logits)
            cross_entropy = tf.multiply(cross_entropy, response_mask)
            cross_entropy = tf.reduce_sum(cross_entropy, axis=1)
            loss = tf.reduce_mean(cross_entropy)
            l2_norm = hp.lambda_l2 * tf.add_n([
                tf.nn.l2_loss(var)
                for var in tf.trainable_variables() if 'bias' not in var.name
            ])
            loss = loss + l2_norm

            debug_tensors = []
            return loss, debug_tensors
        elif mode == modekeys.EVAL:
            sequence_length = tf.constant(value=hp.max_sentence_length,
                                          dtype=tf.int32,
                                          shape=[batch_size])
            helper = TrainingHelper(inputs=response_in,
                                    sequence_length=sequence_length)
            decoder = BasicDecoder(cell=attn_cell,
                                   helper=helper,
                                   initial_state=attn_cell.zero_state(
                                       batch_size, tf.float32),
                                   output_layer=output_layer)
            final_outputs, final_state, final_sequence_lengths = dynamic_decode(
                decoder=decoder,
                impute_finished=True,
                parallel_iterations=32,
                swap_memory=True)
            logits = final_outputs.rnn_output
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=response_out, logits=logits)
            # Average over real tokens only, then exponentiate for perplexity.
            cross_entropy = (tf.reduce_sum(cross_entropy * response_mask) /
                             tf.reduce_sum(response_mask))
            ppl = tf.exp(cross_entropy)
            return ppl
        elif mode == modekeys.PREDICT:
            if hp.beam_width == 0:
                helper = GreedyEmbeddingHelper(
                    embedding=embedding_w,
                    start_tokens=tf.constant(1, tf.int32, shape=[batch_size]),
                    end_token=2)
                initial_state = attn_cell.zero_state(batch_size=batch_size,
                                                     dtype=tf.float32)
                decoder = BasicDecoder(cell=attn_cell,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=output_layer)
                final_outputs, final_state, final_sequence_lengths = dynamic_decode(
                    decoder, maximum_iterations=hp.max_sentence_length)
                results = {}
                results['response_ids'] = final_outputs.sample_id
                results['response_lens'] = final_sequence_lengths
                return results
            else:
                decoder_initial_state = attn_cell.zero_state(
                    batch_size=batch_size * hp.beam_width, dtype=tf.float32)
                decoder = BeamSearchDecoder(
                    cell=attn_cell,
                    embedding=embedding_w,
                    start_tokens=tf.constant(1, tf.int32, shape=[batch_size]),
                    end_token=2,
                    initial_state=decoder_initial_state,
                    beam_width=hp.beam_width,
                    output_layer=output_layer)
                final_outputs, final_state, final_sequence_lengths = dynamic_decode(
                    decoder,
                    impute_finished=False,
                    maximum_iterations=hp.max_sentence_length)

                # [batch, steps, beam_width] -> [batch, beam_width, steps]
                final_outputs = tf.transpose(final_outputs.predicted_ids,
                                             perm=[0, 2, 1])
                # predicted_length = final_state.lengths  # b,s
                predicted_length = None

                results = {}
                results['response_ids'] = final_outputs
                results['response_lens'] = predicted_length
                return results
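
The EVAL branch above averages cross-entropy over real tokens only before exponentiating. A small numpy sketch of that masked perplexity, with made-up numbers:

import numpy as np

# Hypothetical per-token cross-entropy and a mask marking non-pad tokens.
ce = np.array([[2.0, 1.0, 3.0],
               [1.5, 0.5, 0.0]])
mask = np.array([[1.0, 1.0, 0.0],
                 [1.0, 1.0, 0.0]])

mean_ce = (ce * mask).sum() / mask.sum()  # (2 + 1 + 1.5 + 0.5) / 4 = 1.25
ppl = np.exp(mean_ce)                     # ~3.49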
Example #5
    def _build_decoder(self, encoder_outputs, encoder_state):
        with tf.name_scope("seq_decoder"):
            batch_size = self.batch_size
            # sequence_length = tf.fill([self.batch_size], self.num_steps)
            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                sequence_length = self.iterator.target_length
            else:
                sequence_length = self.iterator.source_length
            if (self.mode !=
                    tf.contrib.learn.ModeKeys.TRAIN) and self.beam_width > 1:
                batch_size = batch_size * self.beam_width
                encoder_outputs = beam_search_decoder.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                encoder_state = nest.map_structure(
                    lambda s: beam_search_decoder.tile_batch(
                        s, self.beam_width), encoder_state)
                sequence_length = beam_search_decoder.tile_batch(
                    sequence_length, multiplier=self.beam_width)

            # Build a fresh cell per layer; reusing one cell instance across
            # MultiRNNCell layers triggers variable-scope reuse errors in TF1.
            decoder_cell = MultiRNNCell([
                single_rnn_cell(self.hparams.unit_type, self.num_units,
                                self.dropout)
                for _ in range(self.num_layers_decoder)
            ])
            decoder_cell = InputProjectionWrapper(decoder_cell,
                                                  num_proj=self.num_units)
            attention_mechanism = create_attention_mechanism(
                self.hparams.attention_mechanism,
                self.num_units,
                memory=encoder_outputs,
                source_sequence_length=sequence_length)
            decoder_cell = wrapper.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.num_units,
                output_attention=True,
                alignment_history=False)

            # Set the AttentionWrapperState's cell_state to the encoder's
            # final state (tiled above when beam search is used).
            initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            initial_state = initial_state.clone(cell_state=encoder_state)
            embeddings_decoder = tf.get_variable(
                "embedding_decoder",
                [self.num_decoder_symbols, self.num_units],
                initializer=self.initializer,
                dtype=tf.float32)
            output_layer = Dense(units=self.num_decoder_symbols,
                                 use_bias=True,
                                 name="output_layer")

            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                decoder_inputs = tf.nn.embedding_lookup(
                    embeddings_decoder, self.iterator.target_in)
                decoder_helper = helper.TrainingHelper(
                    decoder_inputs, sequence_length=sequence_length)

                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
                final_outputs, final_state, _ = decoder.dynamic_decode(dec)
                output_ids = final_outputs.sample_id
                outputs = final_outputs.rnn_output
            else:

                def embedding_fn(inputs):
                    return tf.nn.embedding_lookup(embeddings_decoder, inputs)

                decoding_length_factor = 2.0
                max_encoder_length = tf.reduce_max(self.iterator.source_length)
                maximum_iterations = tf.to_int32(
                    tf.round(
                        tf.to_float(max_encoder_length) *
                        decoding_length_factor))

                tgt_sos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                    tf.int32)
                tgt_eos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                    tf.int32)
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                if self.beam_width == 1:
                    decoder_helper = helper.GreedyEmbeddingHelper(
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    dec = basic_decoder.BasicDecoder(decoder_cell,
                                                     decoder_helper,
                                                     initial_state,
                                                     output_layer=output_layer)
                else:
                    dec = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        output_layer=output_layer,
                        beam_width=self.beam_width)
                final_outputs, final_state, _ = decoder.dynamic_decode(
                    dec,
                    # swap_memory=True,
                    maximum_iterations=maximum_iterations)
                # This branch never runs in TRAIN mode, so only the beam width
                # distinguishes greedy from beam-search outputs.
                if self.beam_width == 1:
                    output_ids = final_outputs.sample_id
                    outputs = final_outputs.rnn_output
                else:
                    output_ids = final_outputs.predicted_ids
                    outputs = final_outputs.beam_search_decoder_output.scores

            return output_ids, outputs
Example #6
    def _create_decoder_cell(self):
        enc_outputs, enc_states, enc_seq_len = self.enc_outputs, self.enc_states, self.enc_seq_len
        if self.use_beam_search:
            enc_outputs = tile_batch(enc_outputs,
                                     multiplier=self.cfg.beam_size)
            enc_states = nest.map_structure(
                lambda s: tile_batch(s, self.cfg.beam_size), enc_states)
            enc_seq_len = tile_batch(self.enc_seq_len,
                                     multiplier=self.cfg.beam_size)
        batch_size = (self.batch_size * self.cfg.beam_size
                      if self.use_beam_search else self.batch_size)
        with tf.variable_scope("attention"):
            if self.cfg.attention == "luong":  # Luong attention mechanism
                attention_mechanism = LuongAttention(
                    num_units=self.cfg.num_units,
                    memory=enc_outputs,
                    memory_sequence_length=enc_seq_len)
            else:  # default using Bahdanau attention mechanism
                attention_mechanism = BahdanauAttention(
                    num_units=self.cfg.num_units,
                    memory=enc_outputs,
                    memory_sequence_length=enc_seq_len)

        def cell_input_fn(inputs, attention):
            # Keep the cell input dimension fixed by projecting the
            # concatenated [inputs; attention] back to num_units.
            # Reference: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/AttentionWrapper
            if not self.cfg.use_attention_input_feeding:
                return inputs
            input_project = tf.layers.Dense(self.cfg.num_units,
                                            dtype=tf.float32,
                                            name='attn_input_feeding')
            return input_project(tf.concat([inputs, attention], axis=-1))

        if self.cfg.top_attention:  # apply attention mechanism only on the top decoder layer
            cells = [
                self._create_rnn_cell() for _ in range(self.cfg.num_layers)
            ]
            cells[-1] = AttentionWrapper(
                cells[-1],
                attention_mechanism=attention_mechanism,
                name="Attention_Wrapper",
                attention_layer_size=self.cfg.num_units,
                initial_cell_state=enc_states[-1],
                cell_input_fn=cell_input_fn)
            initial_state = list(enc_states)
            initial_state[-1] = cells[-1].zero_state(batch_size=batch_size,
                                                     dtype=tf.float32)
            dec_init_states = tuple(initial_state)
            cells = MultiRNNCell(cells)
        else:
            cells = MultiRNNCell(
                [self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
            cells = AttentionWrapper(cells,
                                     attention_mechanism=attention_mechanism,
                                     name="Attention_Wrapper",
                                     attention_layer_size=self.cfg.num_units,
                                     initial_cell_state=enc_states,
                                     cell_input_fn=cell_input_fn)
            dec_init_states = cells.zero_state(
                batch_size=batch_size,
                dtype=tf.float32).clone(cell_state=enc_states)
        return cells, dec_init_states
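
A hedged sketch of how the returned (cells, dec_init_states) pair might be consumed for training; dec_inputs, dec_seq_len, and vocab_size are illustrative placeholders, not defined in the example above:

        # Assumed available: dec_inputs [batch, time, dim], dec_seq_len [batch].
        helper = tf.contrib.seq2seq.TrainingHelper(dec_inputs, dec_seq_len)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=cells,
            helper=helper,
            initial_state=dec_init_states,
            output_layer=tf.layers.Dense(vocab_size))
        outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(decoder)
        logits = outputs.rnn_output  # [batch, time, vocab_size]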