Example #1
 def _build_decoder(self, decoder_cell, batch_size, lstm_holistic_features,
                    cnn_fmap):
     embedding_fn = functools.partial(tf.one_hot, depth=self.num_classes)
     output_layer = Compute2dAttentionLayer(512, self.num_classes, cnn_fmap)
     if self._is_training:
         train_helper = seq2seq.TrainingHelper(
             embedding_fn(self._groundtruth_dict['decoder_inputs']),
             sequence_length=self._groundtruth_dict['decoder_lengths'],
             time_major=False)
         decoder = seq2seq.BasicDecoder(
             cell=decoder_cell,
             helper=train_helper,
             initial_state=lstm_holistic_features,
             output_layer=output_layer)
     else:
         lstm0_state_tile = tf.nn.rnn_cell.LSTMStateTuple(
             tf.tile(lstm_holistic_features[0].c, [self._beam_width, 1]),
             tf.tile(lstm_holistic_features[0].h, [self._beam_width, 1]))
         lstm1_state_tile = tf.nn.rnn_cell.LSTMStateTuple(
             tf.tile(lstm_holistic_features[1].c, [self._beam_width, 1]),
             tf.tile(lstm_holistic_features[1].h, [self._beam_width, 1]))
         lstm_holistic_features_tile = (lstm0_state_tile, lstm1_state_tile)
         decoder = seq2seq.BeamSearchDecoder(
             cell=decoder_cell,
             embedding=embedding_fn,
             start_tokens=tf.fill([batch_size], self.start_label),
             end_token=self.end_label,
             initial_state=lstm_holistic_features_tile,
             beam_width=self._beam_width,
             output_layer=output_layer,
             length_penalty_weight=0.0)
     return decoder
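A note on the manual state tiling in this example: `seq2seq.tile_batch` maps over (possibly nested) structures of tensors, including tuples of `LSTMStateTuple`s, so the two hand-built state tuples above could likely be collapsed into a single call. A minimal sketch under that assumption:

     # tile_batch tiles every tensor inside a nested state structure,
     # equivalent to the per-tensor tf.tile calls above
     lstm_holistic_features_tile = seq2seq.tile_batch(
         lstm_holistic_features, multiplier=self._beam_width)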
Example #2
    def model(self):
        with tf.variable_scope("encoder"):
            encoder_cell = self._create_rnn_cell()
            source_embedding = tf.get_variable(name="source_embedding",
                                               shape=[self.source_vocab_size, self.embedding_size],
                                               initializer=tf.initializers.truncated_normal())
            encoder_embedding_inputs = tf.nn.embedding_lookup(source_embedding, self.source_input)
            encoder_outputs, encoder_states = tf.nn.dynamic_rnn(cell=encoder_cell,
                                                                inputs=encoder_embedding_inputs,
                                                                dtype=tf.float32)
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            if self.mode == "test":
                encoder_states = seq2seq.tile_batch(encoder_states, self.beam_size)
            decoder_cell = self._create_rnn_cell()
            # apply dropout only while training; keep everything at test time
            decoder_cell = rnn.DropoutWrapper(
                decoder_cell,
                output_keep_prob=0.5 if self.mode == "train" else 1.0)

            if self.mode == "test":
                batch_size = self.batch_size * self.beam_size
            else:
                batch_size = self.batch_size
            #decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size,dtype=tf.float32)

            output_layer = tf.layers.Dense(units=self.target_vocab_size,
                                           kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
            target_embedding = tf.get_variable(name="target_embedding",
                                               shape=[self.target_vocab_size, self.embedding_size])
            if self.mode == "train":
                self.mask = tf.sequence_mask(self.target_length, self.max_target_length, dtype=tf.float32)
                # drop the last target token and prepend the GO id (2) to form decoder inputs
                del_end = tf.strided_slice(self.target_input, [0, 0], [self.batch_size, -1], [1, 1])
                decoder_input = tf.concat([tf.fill([self.batch_size, 1], 2), del_end], axis=1)
                decoder_input_embedding = tf.nn.embedding_lookup(target_embedding, decoder_input)
                training_helper = seq2seq.TrainingHelper(inputs=decoder_input_embedding,
                                                         sequence_length=tf.fill([self.batch_size], self.max_target_length))
                training_decoder = seq2seq.BasicDecoder(cell=decoder_cell,
                                                        helper=training_helper,
                                                        initial_state=encoder_states,
                                                        output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=training_decoder, output_time_major=False,
                                                               impute_finished=True,
                                                               maximum_iterations=self.max_target_length)
                self.decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
                self.decoder_predict_train = tf.argmax(self.decoder_logits_train, axis=-1)
                self.loss_op = tf.reduce_mean(tf.losses.softmax_cross_entropy(
                    onehot_labels=tf.one_hot(self.target_input, depth=self.target_vocab_size),
                    logits=self.decoder_logits_train, weights=self.mask))
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss_op, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
                self.train_op = optimizer.apply_gradients(zip(clip_gradients, trainable_params))
            elif self.mode == "test":
                # GO id = 2 (matching the fill in the training branch), EOS id = 3
                start_tokens = tf.fill([self.batch_size], value=2)
                end_token = 3
                inference_decoder = seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=target_embedding,
                                                              start_tokens=start_tokens, end_token=end_token,
                                                              initial_state=encoder_states,
                                                              beam_width=self.beam_size, output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=inference_decoder, maximum_iterations=self.max_target_length)
                print(decoder_outputs.predicted_ids.get_shape().as_list())
                self.decoder_predict_decode = decoder_outputs.predicted_ids[:, :, 0]
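The hand-rolled masked loss in the training branch can also be written with `seq2seq.sequence_loss`, which applies per-timestep softmax cross-entropy, the mask, and batch/time averaging in one call. A sketch reusing the same `decoder_logits_train`, `target_input`, and `mask` from above:

    # sequence_loss(logits [B, T, V], targets [B, T] int ids, weights [B, T])
    self.loss_op = seq2seq.sequence_loss(logits=self.decoder_logits_train,
                                         targets=self.target_input,
                                         weights=self.mask)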
Example #3
 def _build_decoder(self, decoder_cell, batch_size):
     embedding_fn = functools.partial(tf.one_hot, depth=self.num_classes)
     output_layer = tf.layers.Dense(
         self.num_classes,
         activation=None,
         use_bias=True,
         kernel_initializer=tf.variance_scaling_initializer(),
         bias_initializer=tf.zeros_initializer())
     if self._is_training:
         train_helper = seq2seq.TrainingHelper(
             embedding_fn(self._groundtruth_dict['decoder_inputs']),
             sequence_length=self._groundtruth_dict['decoder_lengths'],
             time_major=False)
         decoder = seq2seq.BasicDecoder(
             cell=decoder_cell,
             helper=train_helper,
             initial_state=decoder_cell.zero_state(batch_size, tf.float32),
             output_layer=output_layer)
     else:
         decoder = seq2seq.BeamSearchDecoder(
             cell=decoder_cell,
             embedding=embedding_fn,
             start_tokens=tf.fill([batch_size], self.start_label),
             end_token=self.end_label,
             initial_state=decoder_cell.zero_state(
                 batch_size * self._beam_width, tf.float32),
             beam_width=self._beam_width,
             output_layer=output_layer,
             length_penalty_weight=0.0)
     return decoder
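Examples #1 and #3 pass a callable (`functools.partial(tf.one_hot, ...)`) where most of the other examples pass an embedding matrix: `BeamSearchDecoder`'s `embedding` argument accepts either the params for `embedding_lookup` or a callable applied to the sampled ids. A minimal sketch of the two forms (sizes hypothetical):

    import functools
    import tensorflow as tf

    vocab_size, embed_dim = 10000, 256  # hypothetical sizes
    # form 1: a table, looked up internally via embedding_lookup
    embedding_matrix = tf.get_variable('embedding', [vocab_size, embed_dim])
    # form 2: any callable mapping int ids to dense vectors
    embedding_fn = functools.partial(tf.one_hot, depth=vocab_size)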
Example #4
    def _build_decoder_test_beam_search(self):
        r"""
        Builds a beam search test decoder
        """
        if self._hparams.enable_attention is True:
            cells, initial_state = self._add_attention(self._decoder_cells, beam_search=True)
        else:
            cells = self._decoder_cells

            # even without attention, the beam decoder needs its initial state
            # tiled to batch_size * beam_width
            decoder_initial_state_tiled = seq2seq.tile_batch(
                self._decoder_initial_state, multiplier=self._hparams.beam_width)
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([self._batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.6,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
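This decoder sets `length_penalty_weight=0.6` (Example #7 below uses 0.5; most of the other examples use 0.0, which disables the penalty). In `tf.contrib.seq2seq`, each beam's accumulated log-probability is divided by a GNMT-style length penalty; a sketch of the formula it implements:

    # score = log_prob / ((5. + length) / 6.) ** length_penalty_weight
    # weight 0.0 leaves scores unchanged; larger weights favor longer beams
    def length_penalty(sequence_length, weight):
        return ((5.0 + float(sequence_length)) / 6.0) ** weight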
Example #5
 def inference_decode_layer(self, start_token, dec_cell, end_token,
                            output_layer):
     start_tokens = tf.tile(tf.constant([start_token], dtype=tf.int32),
                            [self.batch_size],
                            name='start_token')
     tiled_enc_output = seq2seq.tile_batch(self.enc_output,
                                           multiplier=self.Beam_width)
     tiled_enc_state = seq2seq.tile_batch(self.enc_state,
                                          multiplier=self.Beam_width)
     tiled_source_len = seq2seq.tile_batch(self.source_len,
                                           multiplier=self.Beam_width)
     atten_mech = seq2seq.BahdanauAttention(self.hidden_dim * 2,
                                            tiled_enc_output,
                                            tiled_source_len,
                                            normalize=True)
     decoder_att = seq2seq.AttentionWrapper(dec_cell, atten_mech,
                                            self.hidden_dim * 2)
     initial_state = decoder_att.zero_state(
         self.batch_size * self.Beam_width,
         tf.float32).clone(cell_state=tiled_enc_state)
     decoder = seq2seq.BeamSearchDecoder(decoder_att,
                                         self.embeddings,
                                         start_tokens,
                                         end_token,
                                         initial_state,
                                         beam_width=self.Beam_width,
                                         output_layer=output_layer)
     infer_logits, _, _ = seq2seq.dynamic_decode(decoder,
                                                 output_time_major=False,
                                                 impute_finished=False,
                                                 maximum_iterations=self.max_target_len)
     return infer_logits
Example #6
def beam_eval_decoder(agenda,
                      embeddings,
                      extended_base_words,
                      oov,
                      start_token_id,
                      stop_token_id,
                      base_sent_hiddens,
                      base_length,
                      vocab_size,
                      attn_dim,
                      hidden_dim,
                      num_layer,
                      max_sentence_length,
                      beam_width,
                      swap_memory,
                      enable_dropout=False,
                      dropout_keep=1.,
                      no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder', reuse=True):
        true_batch_size = tf.shape(base_sent_hiddens)[0]

        tiled_agenda = seq2seq.tile_batch(agenda, beam_width)
        tiled_extended_base_words = seq2seq.tile_batch(extended_base_words,
                                                       beam_width)
        tiled_oov = seq2seq.tile_batch(oov, beam_width)

        tiled_base_sent = seq2seq.tile_batch(base_sent_hiddens, beam_width)
        tiled_base_lengths = seq2seq.tile_batch(base_length, beam_width)

        start_token_id = tf.cast(start_token_id, tf.int32)
        stop_token_id = tf.cast(stop_token_id, tf.int32)

        cell, zero_states = create_decoder_cell(
            tiled_agenda,
            tiled_extended_base_words,
            tiled_oov,
            tiled_base_sent,
            tiled_base_lengths,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn,
            beam_width=beam_width)

        decoder = seq2seq.BeamSearchDecoder(cell,
                                            create_embedding_fn(vocab_size),
                                            tf.fill([true_batch_size],
                                                    start_token_id),
                                            stop_token_id,
                                            zero_states,
                                            beam_width=beam_width,
                                            length_penalty_weight=0.0)

        return seq2seq.dynamic_decode(decoder,
                                      maximum_iterations=max_sentence_length,
                                      swap_memory=swap_memory)
Example #7
    def _build_decoder_beam_search(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        attention_mechanisms, layer_sizes = self._create_attention_mechanisms(
            beam_search=True)

        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state, multiplier=self._hparams.beam_width)

        if self._hparams.enable_attention is True:

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=decoder_initial_state_tiled,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention)

            initial_state = attention_cells.zero_state(
                dtype=self._hparams.dtype,
                batch_size=batch_size * self._hparams.beam_width)

            initial_state = initial_state.clone(
                cell_state=decoder_initial_state_tiled)

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.5,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment is True:
            self.attention_summary = self._create_attention_alignments_summary(
                states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output
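Examples #5 and #7 both follow the canonical attention-plus-beam recipe: tile the memory, the memory lengths, and the encoder final state by `beam_width`, build the attention wrapper over the tiled memory, take its zero state at `batch_size * beam_width`, and clone in the tiled encoder state. A condensed, self-contained sketch of that recipe (names hypothetical):

    import tensorflow as tf
    from tensorflow.contrib import seq2seq

    def beam_attention_cell(cell, memory, memory_len, enc_state,
                            batch_size, beam_width, num_units):
        # 1) tile everything the decoder attends over or is initialized from
        tiled_memory = seq2seq.tile_batch(memory, multiplier=beam_width)
        tiled_memory_len = seq2seq.tile_batch(memory_len, multiplier=beam_width)
        tiled_enc_state = seq2seq.tile_batch(enc_state, multiplier=beam_width)
        # 2) attention is computed over the tiled memory
        attention = seq2seq.BahdanauAttention(
            num_units, tiled_memory, memory_sequence_length=tiled_memory_len)
        attn_cell = seq2seq.AttentionWrapper(
            cell, attention, attention_layer_size=num_units)
        # 3) zero state at batch_size * beam_width, seeded with the encoder state
        initial_state = attn_cell.zero_state(
            batch_size * beam_width, tf.float32).clone(cell_state=tiled_enc_state)
        return attn_cell, initial_state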
Example #8
    def _build_decoder_test_beam_search(self):
        r"""
        Builds a beam search test decoder
        """
        if self._hparams.enable_attention is True:
            cells, initial_state = add_attention(
                cells=self._decoder_cells,
                attention_types=self._hparams.attention_type[1],
                num_units=self._hparams.decoder_units_per_layer[-1],
                memory=self._encoder_memory,
                memory_len=self._encoder_features_len,
                beam_search=True,
                batch_size=self._batch_size,
                beam_width=self._hparams.beam_width,
                initial_state=self._decoder_initial_state,
                mode=self._mode,
                dtype=self._hparams.dtype,
                fusion_type='linear_fusion',
                write_attention_alignment=self._hparams.write_attention_alignment)
        else:
            cells = self._decoder_cells

            # even without attention, the beam decoder needs its initial state
            # tiled to batch_size * beam_width
            decoder_initial_state_tiled = seq2seq.tile_batch(
                self._decoder_initial_state,
                multiplier=self._hparams.beam_width)
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([self._batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.6,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment is True:
            self.attention_summary, self.attention_alignment = self._create_attention_alignments_summary(
                states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output
Example #9
    def _build_decoder(self):
        with tf.variable_scope("dialog_decoder"):
            with tf.variable_scope("decoder_output_projection"):
                output_layer = layers_core.Dense(
                    self.config.vocab_size, use_bias=False, name="output_projection")

            with tf.variable_scope("decoder_rnn"):
                dec_cell, dec_init_state = self._build_decoder_cell(enc_outputs=self.encoder_outputs,
                                                                    enc_state=self.encoder_state)

                # Training or eval: decode with teacher forcing on ground-truth inputs
                if self.mode != ModelMode.infer:
                    resp_emb_inp = tf.nn.embedding_lookup(self.decoder_embeddings, self.target_input)
                    helper = tc_seq2seq.TrainingHelper(resp_emb_inp, self.target_length)
                    decoder = tc_seq2seq.BasicDecoder(
                        cell=dec_cell,
                        helper=helper,
                        initial_state=dec_init_state,
                        output_layer=output_layer
                    )

                    dec_outputs, dec_state, _ = tc_seq2seq.dynamic_decode(decoder)
                    sample_id = dec_outputs.sample_id
                    logits = dec_outputs.rnn_output

                else:
                    beam_width = self.config.beam_size
                    length_penalty_weight = self.config.length_penalty_weight
                    maximum_iterations = tf.to_int32(self.config.infer_max_len)
                    start_tokens = tf.fill([self.batch_size], self.config.sos_idx)
                    end_token = self.config.eos_idx

                    # beam-search decoder for inference
                    decoder = tc_seq2seq.BeamSearchDecoder(
                        cell=dec_cell,
                        embedding=self.decoder_embeddings,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=dec_init_state,
                        beam_width=beam_width,
                        output_layer=output_layer,
                        length_penalty_weight=length_penalty_weight)

                    dec_outputs, dec_state, _ = tc_seq2seq.dynamic_decode(
                        decoder,
                        maximum_iterations=maximum_iterations,
                    )
                    logits = tf.no_op()
                    sample_id = dec_outputs.predicted_ids

                self.logits = logits
                self.sample_id = sample_id
Example #10
def beam_eval_decoder(agenda, embeddings, start_token_id, stop_token_id,
                      base_sent_hiddens, insert_word_embeds, delete_word_embeds,
                      base_length, iw_length, dw_length,
                      attn_dim, hidden_dim, num_layer, maximum_iterations, beam_width, swap_memory,
                      enable_dropout=False, dropout_keep=1., no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder', reuse=True):
        true_batch_size = tf.shape(base_sent_hiddens)[0]

        tiled_agenda = seq2seq.tile_batch(agenda, beam_width)

        tiled_base_sent = seq2seq.tile_batch(base_sent_hiddens, beam_width)
        tiled_insert_embeds = seq2seq.tile_batch(insert_word_embeds, beam_width)
        tiled_delete_embeds = seq2seq.tile_batch(delete_word_embeds, beam_width)

        tiled_src_lengths = seq2seq.tile_batch(base_length, beam_width)
        tiled_iw_lengths = seq2seq.tile_batch(iw_length, beam_width)
        tiled_dw_lengths = seq2seq.tile_batch(dw_length, beam_width)

        start_token_id = tf.cast(start_token_id, tf.int32)
        stop_token_id = tf.cast(stop_token_id, tf.int32)

        cell = create_decoder_cell(
            tiled_agenda,
            tiled_base_sent, tiled_insert_embeds, tiled_delete_embeds,
            tiled_src_lengths, tiled_iw_lengths, tiled_dw_lengths,
            attn_dim, hidden_dim, num_layer,
            enable_dropout=enable_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn
        )

        output_layer = DecoderOutputLayer(embeddings, beam_decoder=True)
        zero_states = create_trainable_zero_state(cell, true_batch_size, beam_width)

        decoder = seq2seq.BeamSearchDecoder(
            cell,
            embeddings,
            tf.fill([true_batch_size], start_token_id),
            stop_token_id,
            zero_states,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0
        )

        return seq2seq.dynamic_decode(decoder, maximum_iterations=maximum_iterations, swap_memory=swap_memory)
Example #11
    def get_beam_ids(self, cell, projection_layer):
        initial_state = cell.zero_state(self.batch_size *
                                        self.config.BEAM_WIDTH,
                                        dtype=tf.float32)

        if self.config.LEN_EMB_SIZE > 0:
            output_seq_len = seq2seq.tile_batch(
                self.output_len, multiplier=self.config.BEAM_WIDTH)
            cell = LenControlWrapper(cell,
                                     output_seq_len,
                                     self.len_embeddings,
                                     initial_cell_state=initial_state)
            initial_state = cell.zero_state(self.batch_size *
                                            self.config.BEAM_WIDTH,
                                            dtype=tf.float32)

        latent_variables = seq2seq.tile_batch(
            self.latent_variables, multiplier=self.config.BEAM_WIDTH)
        cell = AlignmentWrapper(cell,
                                latent_variables,
                                initial_cell_state=initial_state)
        initial_state = cell.zero_state(self.batch_size *
                                        self.config.BEAM_WIDTH,
                                        dtype=tf.float32)

        if not self.is_training:
            decoder = seq2seq.BeamSearchDecoder(
                cell,
                self.embedding,
                self.go_input(),
                self.eos_idx,
                initial_state=initial_state,
                beam_width=self.config.BEAM_WIDTH,
                output_layer=projection_layer)
            outputs, _, seq_len = seq2seq.dynamic_decode(
                decoder, maximum_iterations=tf.reduce_max(self.output_len))
            return outputs.predicted_ids[:, :, 0]
Example #12
def _build_decoder_action(model, dialogue_state, hparams, start_token,
                          end_token, output_layer):
  """build the decoder for action states."""

  iterator = model.iterator

  start_token_id = tf.cast(
      model.vocab_table.lookup(tf.constant(start_token)), tf.int32)
  end_token_id = tf.cast(
      model.vocab_table.lookup(tf.constant(end_token)), tf.int32)

  start_tokens = tf.fill([model.batch_size], start_token_id)
  end_token = end_token_id

  # kb is not used again
  ## Decoder.
  with tf.variable_scope("action_decoder") as decoder_scope:
    # we initialize the cell with the last layer of the last hidden state
    cell, decoder_initial_state = _build_action_decoder_cell(
        model, hparams, dialogue_state, model.global_gpu_num)
    model.global_gpu_num += 1
    ## Train or eval
    # situation one, for train, eval, mutable train
    # decoder_emp_inp: [max_time, batch_size, num_units]
    action = iterator.action
    # shift action
    paddings = tf.constant([[0, 0], [1, 0]])
    action = tf.pad(action, paddings, "CONSTANT", constant_values=0)[:, :-1]
    decoder_emb_inp = tf.nn.embedding_lookup(model.embedding_decoder,
                                             action)

    # Helper
    helper_train = seq2seq.TrainingHelper(
        decoder_emb_inp, iterator.action_len, time_major=False)

    # Decoder
    my_decoder_train = seq2seq.BasicDecoder(
        cell, helper_train, decoder_initial_state, output_layer)

    # Dynamic decoding
    outputs_train, _, _ = seq2seq.dynamic_decode(
        my_decoder_train,
        output_time_major=False,
        swap_memory=True,
        scope=decoder_scope)

    sample_id_train = outputs_train.sample_id
    logits_train = outputs_train.rnn_output
    # inference

    beam_width = hparams.beam_width
    length_penalty_weight = hparams.length_penalty_weight

    if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
      my_decoder_infer = seq2seq.BeamSearchDecoder(
          cell=cell,
          embedding=model.embedding_decoder,
          start_tokens=start_tokens,
          end_token=end_token,
          initial_state=decoder_initial_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=length_penalty_weight)
    else:
      # Helper
      if model.mode in dialogue_utils.self_play_modes:
        helper_infer = seq2seq.SampleEmbeddingHelper(
            model.embedding_decoder, start_tokens, end_token)
      else:
        helper_infer = seq2seq.GreedyEmbeddingHelper(
            model.embedding_decoder, start_tokens, end_token)

      # Decoder
      my_decoder_infer = seq2seq.BasicDecoder(
          cell,
          helper_infer,
          decoder_initial_state,
          output_layer=output_layer  # applied per timestep
      )

    # Dynamic decoding
    outputs_infer, _, _ = seq2seq.dynamic_decode(
        my_decoder_infer,
        maximum_iterations=hparams.len_action,
        output_time_major=False,
        swap_memory=True,
        scope=decoder_scope)

    if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
      logits_infer = tf.no_op()
      sample_id_infer = outputs_infer.predicted_ids
    else:
      logits_infer = outputs_infer.rnn_output
      sample_id_infer = outputs_infer.sample_id

  return logits_train, logits_infer, sample_id_train, sample_id_infer
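This is the only example that switches between `SampleEmbeddingHelper` and `GreedyEmbeddingHelper`; the two share a constructor signature and differ only in how the next input id is chosen. A minimal sketch:

    # greedy: argmax over the logits at each step
    helper = seq2seq.GreedyEmbeddingHelper(embedding, start_tokens, end_token)
    # sampling: draw the next id from the softmax; the optional
    # softmax_temperature flattens (>1.0) or sharpens (<1.0) the distribution
    helper = seq2seq.SampleEmbeddingHelper(embedding, start_tokens, end_token,
                                           softmax_temperature=1.0)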
Example #13
    def create_model_predict(self, input, mode='decode'):
        # NOTE: "beam_with" is this project's (misspelled) config key for the beam width
        use_beam_search = self.params.beam_with > 1
        with tf.variable_scope("attetnion_seq2seq", reuse=tf.AUTO_REUSE):
            embeddings_matrix = self._create_embedding()

            keep_prob = 1 - self.params.dropout_rate
            batch_size = tf.shape(input)[0]
            # encoder
            encoder_outputs, encoder_last_states, encoder_inputs_length = self._create_encoder(
                embeddings_matrix, input, keep_prob)

            # decoder
            with tf.variable_scope('decoder'):
                # Output projection layer to convert cell outputs to logits
                output_layer = Dense(self.params.vocab_size,
                                     name='output_project')
                input_layer = Dense(self.params.hidden_units * 2,
                                    dtype=tf.float32,
                                    name='input_projection')
                decoder_cell, decoder_initial_state = create_decoder_cell(
                    enc_outputs=encoder_outputs,
                    enc_states=encoder_last_states,
                    enc_seq_len=encoder_inputs_length,
                    num_layers=self.params.depth,
                    num_units=self.params.hidden_units * 2,
                    keep_prob=keep_prob,
                    use_residual=self.params.use_residual,
                    use_beam_search=use_beam_search,
                    beam_size=self.params.beam_with,
                    batch_size=batch_size,
                    top_attention=self.params.top_attention)

                # start_tokens: [batch_size] int32 vector filled with GO_ID
                start_tokens = tf.ones([batch_size], tf.int32) * data_utils.GO_ID
                end_token = data_utils.EOS_ID

                def embed_and_input_proj(inputs):
                    return input_layer(
                        tf.nn.embedding_lookup(embeddings_matrix, inputs))

                if self.params.beam_with <= 1:
                    decode_helper = seq2seq.GreedyEmbeddingHelper(
                        start_tokens=start_tokens,
                        end_token=end_token,
                        embedding=embed_and_input_proj)
                    inference_decoder = seq2seq.BasicDecoder(
                        cell=decoder_cell,
                        helper=decode_helper,
                        initial_state=decoder_initial_state,
                        output_layer=output_layer)
                    decoder_output, _, _ = seq2seq.dynamic_decode(
                        decoder=inference_decoder,
                        output_time_major=False,
                        impute_finished=True,
                        maximum_iterations=self.params.max_seq_length)
                else:
                    inference_decoder = seq2seq.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embed_and_input_proj,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=self.params.beam_with,
                        output_layer=output_layer)
                    decoder_output, _, _ = seq2seq.dynamic_decode(
                        decoder=inference_decoder,
                        output_time_major=False,
                        maximum_iterations=self.params.max_seq_length)

                if self.params.beam_with <= 1:
                    decoder_predict = tf.expand_dims(decoder_output.sample_id,
                                                     -1)
                else:
                    decoder_predict = decoder_output.predicted_ids

        decoder_predict = tf.identity(decoder_predict, 'predicts')
        return decoder_predict
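The `tf.expand_dims(..., -1)` in the greedy branch exists so both branches return ids of the same rank: `sample_id` comes back as `[batch, time]`, while `predicted_ids` is already `[batch, time, beam_width]`. Sketched:

    # greedy: sample_id [batch, time]           -> expand to [batch, time, 1]
    # beam:   predicted_ids [batch, time, beam] -> already rank 3
    greedy_ids = tf.expand_dims(decoder_output.sample_id, -1)
    beam_ids = decoder_output.predicted_ids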
Example #14
    def build_decoder(self):
        with tf.variable_scope("decoder"):
            decoder_cell, decoder_initial_state = self.build_decoder_cell()

            # start tokens : [batch_size], which is fed to BeamsearchDecoder during inference
            start_tokens = tf.ones([self.batch_size],
                                   dtype=tf.int32) * data_util.ID_GO
            end_token = data_util.ID_EOS
            input_layer = Dense(self.state_size * 2, dtype=tf.float32,
                                name="input_layer")
            output_layer = Dense(self.decoder_vocab_size,
                                 name="output_projection")
            if self.mode == "train":
                # feed ground truth decoder input token every time step
                decoder_input_lookup = tf.nn.embedding_lookup(
                    self.embedding_matrix, self.decoder_input)
                decoder_input_lookup = input_layer(decoder_input_lookup)
                training_helper = seq2seq.TrainingHelper(
                    inputs=decoder_input_lookup,
                    sequence_length=self.decoder_train_len,
                    name="training_helper")
                training_decoder = seq2seq.BasicDecoder(cell=decoder_cell,
                                                        initial_state=decoder_initial_state,
                                                        helper=training_helper,
                                                        output_layer=output_layer)

                # decoder_outputs_train: BasicDecoderOutput
                #                        namedtuple(rnn_outputs, sample_id)
                # decoder_outputs_train.rnn_output: [batch_size, max_time_step + 1, num_decoder_symbols] if output_time_major=False
                #                                   [max_time_step + 1, batch_size, num_decoder_symbols] if output_time_major=True
                # decoder_outputs_train.sample_id: [batch_size], tf.int32
                max_decoder_len = tf.reduce_max(self.decoder_train_len)
                decoder_outputs_train, final_state, _ = seq2seq.dynamic_decode(
                    training_decoder, impute_finished=True, swap_memory=True,
                    maximum_iterations=max_decoder_len)
                self.decoder_logits_train = tf.identity(
                    decoder_outputs_train.rnn_output)
                decoder_pred = tf.argmax(self.decoder_logits_train, axis=2)
                # sequence mask for get valid sequence except zero padding
                weights = tf.sequence_mask(self.decoder_len,
                                           maxlen=max_decoder_len,
                                           dtype=tf.float32)
                # compute cross entropy loss for all sequence prediction and ignore loss from zero padding
                self.loss = seq2seq.sequence_loss(
                    logits=self.decoder_logits_train,
                    targets=self.decoder_target,
                    weights=weights, average_across_batch=True,
                    average_across_timesteps=True)
                tf.summary.scalar("loss", self.loss)

                with tf.variable_scope("train_optimizer") and tf.device(
                        "/device:GPU:1"):
                    # use AdamOptimizer and clip gradient by max_norm 5.0
                    # use global step for counting every iteration
                    params = tf.trainable_variables()
                    gradients = tf.gradients(self.loss, params)
                    clipped_gradients, _ = tf.clip_by_global_norm(gradients,
                                                                  5.0)
                    learning_rate = tf.train.exponential_decay(self.lr,
                                                               self.global_step,
                                                               10000, 0.96)
                    opt = tf.train.AdagradOptimizer(learning_rate)

                    self.train_op = opt.apply_gradients(
                        zip(clipped_gradients, params),
                        global_step=self.global_step)

            elif self.mode == "test":
                def embedding_proj(inputs):
                    return input_layer(
                        tf.nn.embedding_lookup(self.embedding_matrix,
                                               inputs))

                inference_decoder = seq2seq.BeamSearchDecoder(cell=decoder_cell,
                                                              embedding=embedding_proj,
                                                              start_tokens=start_tokens,
                                                              end_token=end_token,
                                                              initial_state=decoder_initial_state,
                                                              beam_width=self.beam_depth,
                                                              output_layer=output_layer)

                # For GreedyDecoder, return
                # decoder_outputs_decode: BasicDecoderOutput instance
                #                         namedtuple(rnn_outputs, sample_id)
                # decoder_outputs_decode.rnn_output: [batch_size, max_time_step, num_decoder_symbols] 	if output_time_major=False
                #                                    [max_time_step, batch_size, num_decoder_symbols] 	if output_time_major=True
                # decoder_outputs_decode.sample_id: [batch_size, max_time_step], tf.int32		if output_time_major=False
                #                                   [max_time_step, batch_size], tf.int32               if output_time_major=True

                # For BeamSearchDecoder, return
                # decoder_outputs_decode: FinalBeamSearchDecoderOutput instance
                #                         namedtuple(predicted_ids, beam_search_decoder_output)
                # decoder_outputs_decode.predicted_ids: [batch_size, max_time_step, beam_width] if output_time_major=False
                #                                       [max_time_step, batch_size, beam_width] if output_time_major=True
                # decoder_outputs_decode.beam_search_decoder_output: BeamSearchDecoderOutput instance
                #                                                    namedtuple(scores, predicted_ids, parent_ids)
                with tf.device("/device:GPU:1"):
                    decoder_outputs, decoder_last_state, decoder_output_length = \
                        seq2seq.dynamic_decode(decoder=inference_decoder,
                                               output_time_major=False,
                                               swap_memory=True,
                                               maximum_iterations=self.max_iter)
                    self.decoder_pred_test = decoder_outputs.predicted_ids
Example #15
    attention_mechanism = seq2seq.BahdanauAttention(
        num_units=hidden_dim * 2,
        memory=encoder_outputs,
        memory_sequence_length=encoder_inputs_length)
    #decoder_cell = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(hidden_dim*2) for _ in range(num_layers)])
    decoder_cell = seq2seq.AttentionWrapper(
        cell=global_decoder_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=hidden_dim * 2)

    inference_decoder = seq2seq.BeamSearchDecoder(
        cell=decoder_cell,
        embedding=no_op_embedding,
        start_tokens=tf.fill([batch_size], 12),
        end_token=0,
        initial_state=decoder_cell.zero_state(
            batch_size * beam_width,
            tf.float32).clone(cell_state=encoder_last_state),
        beam_width=beam_width,
        #initial_state = decoder_cell_inf.zero_state(batch_size = batch_size, dtype = tf.float32)
        output_layer=projection_layer)

    print(inference_decoder)
    with tf.variable_scope('decode_with_shared_attention', reuse=True):
        inference_decoder_output, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            impute_finished=False,
            maximum_iterations=tf.reduce_max(encoder_inputs_length))

    for var in tf.trainable_variables():
        print(var)
Example #16
    def _build_sentence_decoder(self, inputs, context_encoder_outputs,
                                sentence_encoder_final_states,
                                sentence_encoder_outputs):
        batch_size = self._batch_size
        num_sentence = self._num_sentence

        word_embedding = model_helper.create_word_embedding(
            num_vocab=self.hparams.num_vocab,
            embedding_dim=self.hparams.word_embedding_dim,
            name='decoder/word_embedding',
            pretrained_word_matrix=self.hparams.pretrained_word_path)

        # tile_batch in inference mode
        beam_width = self.hparams.beam_width
        if self.mode == tf.contrib.learn.ModeKeys.INFER:
            # only decode last timestep
            if 'lstm' in self.hparams.rnn_cell_type.lower():
                batched_sentence_encoder_states = []
                for encoder_state in sentence_encoder_final_states:
                    target_shape = tf.stack([batch_size, num_sentence, -1])
                    c = s2s.tile_batch(
                        tf.reshape(encoder_state.c, target_shape)[:, -1, :],
                        beam_width)
                    h = s2s.tile_batch(
                        tf.reshape(encoder_state.h, target_shape)[:, -1, :],
                        beam_width)
                    batched_sentence_encoder_states.append(
                        tf.contrib.rnn.LSTMStateTuple(c=c, h=h))
            else:
                batched_sentence_encoder_states = [
                    s2s.tile_batch(
                        tf.reshape(encoder_state,
                                   tf.stack([batch_size, num_sentence,
                                             -1]))[:, -1, :], beam_width)
                    for encoder_state in sentence_encoder_final_states
                ]
            sentence_encoder_final_states = tuple(
                batched_sentence_encoder_states)

            sentence_encoder_outputs = s2s.tile_batch(
                tf.reshape(
                    sentence_encoder_outputs,
                    tf.stack([
                        batch_size, num_sentence, -1,
                        self.hparams.num_rnn_units
                    ]))[:, -1, :, :], beam_width)
            source_lengths = s2s.tile_batch(inputs.src_lengths[:, -1],
                                            beam_width)

            context_encoder_outputs = tf.reshape(
                context_encoder_outputs,
                tf.stack(
                    [batch_size, num_sentence,
                     self.hparams.num_rnn_units]))[:, -1, :]
            context_encoder_outputs = tf.tile(
                tf.expand_dims(context_encoder_outputs, axis=1),
                [1, beam_width, 1])
            effective_batch_size = self._batch_size * beam_width
        else:
            source_lengths = tf.reshape(inputs.src_lengths, [-1])
            context_encoder_outputs.set_shape(
                [None, self.hparams.num_rnn_units])
            effective_batch_size = self._batch_size * self._num_sentence

        # Current strategy: No residual layers at decoder
        attention_mechanism = model_helper.create_attention_mechanism(
            attention_option=self.hparams.attention_type,
            num_units=self.hparams.num_rnn_units,
            memory=sentence_encoder_outputs,
            source_length=source_lengths)
        decoder_cell = s2s.AttentionWrapper(
            model_helper.create_rnn_cell(
                cell_type=self.hparams.rnn_cell_type,
                num_layers=self.hparams.num_rnn_layers,
                num_units=self.hparams.num_rnn_units,
                dropout_keep_prob=self._dropout_keep_prob,
                num_residual_layers=0),
            attention_mechanism,
            attention_layer_size=self.hparams.num_rnn_units,
            alignment_history=False,
            name="attention")

        decoder_initial_state = decoder_cell.zero_state(
            effective_batch_size, tf.float32)
        decoder_initial_state = decoder_initial_state.clone(
            cell_state=sentence_encoder_final_states)

        with tf.variable_scope('output_projection'):
            output_layer = layers_core.Dense(self.hparams.num_vocab,
                                             name="output_projection")
            self.output_layer = output_layer

        if self.mode in {
                tf.contrib.learn.ModeKeys.TRAIN, tf.contrib.learn.ModeKeys.EVAL
        }:
            decoder_input_tokens = tf.reshape(
                inputs.targets_in, tf.stack([batch_size * num_sentence, -1]))
            decoder_inputs = tf.nn.embedding_lookup(word_embedding,
                                                    decoder_input_tokens)
            target_lengths = tf.reshape(inputs.tgt_lengths, [-1])

            # scheduled sampling is currently disabled ("and False" short-circuits)
            if self.mode == tf.contrib.learn.ModeKeys.TRAIN and False:
                sampling_probability = 1.0 - tf.train.exponential_decay(
                    1.0,
                    self.global_step,
                    self.hparams.scheduled_sampling_decay_steps,
                    self.hparams.scheduled_sampling_decay_rate,
                    staircase=True,
                    name='scheduled_sampling_prob')
                helper = s2s.ScheduledEmbeddingTrainingHelper(
                    inputs=decoder_inputs,
                    sequence_length=target_lengths,
                    embedding=word_embedding,
                    sampling_probability=sampling_probability,
                    name='scheduled_sampling_helper')
            else:
                helper = s2s.TrainingHelper(
                    inputs=decoder_inputs,
                    sequence_length=target_lengths,
                    name='training_helper',
                )
            decoder = s2s.BasicDecoder(decoder_cell,
                                       helper,
                                       decoder_initial_state,
                                       output_layer=None)
            final_outputs, final_state, _ = dynamic_decode_with_concat(
                decoder, context_encoder_outputs, swap_memory=True)
            logits = final_outputs.rnn_output
            sample_id = final_outputs.sample_id

        else:
            sos_id = tf.cast(self.vocab_table.lookup(tf.constant(dataset.SOS)),
                             tf.int32)
            eos_id = tf.cast(self.vocab_table.lookup(tf.constant(dataset.EOS)),
                             tf.int32)
            sos_ids = tf.fill([batch_size], sos_id)
            decoder = s2s.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=word_embedding,
                start_tokens=sos_ids,
                end_token=eos_id,
                initial_state=decoder_initial_state,
                beam_width=beam_width,
                output_layer=self.output_layer)
            final_outputs, final_state, _ = dynamic_decode_with_concat(
                decoder,
                context_encoder_outputs,
                maximum_iterations=self.hparams.target_max_length,
                swap_memory=True)
            logits = final_outputs.beam_search_decoder_output.scores
            sample_id = final_outputs.predicted_ids

        return logits, final_state, sample_id
Example #17
    def decode(self, encoder_outputs, encoder_state, source_sequence_length):
        with tf.variable_scope("Decoder") as scope:
            beam_width = self.beam_width
            decoder_type = self.decoder_type
            seq_max_len = self.seq_max_len
            batch_size = tf.shape(encoder_outputs)[0]

            if self.path_embed_method == "lstm":
                self.decoder_cell = self._build_decode_cell()
                if self.mode == "test" and beam_width > 0:
                    memory = seq2seq.tile_batch(self.encoder_outputs, multiplier=beam_width)
                    source_sequence_length = seq2seq.tile_batch(self.source_sequence_length, multiplier=beam_width)
                    encoder_state = seq2seq.tile_batch(self.encoder_state, multiplier=beam_width)
                    batch_size = self.batch_size * beam_width
                else:
                    memory = encoder_outputs
                    # source_sequence_length and encoder_state are used untiled

                attention_mechanism = seq2seq.BahdanauAttention(self.hidden_layer_dim, memory,
                                                                memory_sequence_length=source_sequence_length)
                self.decoder_cell = seq2seq.AttentionWrapper(self.decoder_cell, attention_mechanism,
                                                             attention_layer_size=self.hidden_layer_dim)
                self.decoder_initial_state = self.decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)

            projection_layer = Dense(self.word_vocab_size, use_bias=False)

            """For training the model"""
            if self.mode == "train":
                decoder_train_helper = tf.contrib.seq2seq.TrainingHelper(self.decoder_train_inputs_embedded,
                                                                         self.decoder_train_length)
                decoder_train = seq2seq.BasicDecoder(self.decoder_cell, decoder_train_helper,
                                                     self.decoder_initial_state,
                                                     projection_layer)
                decoder_outputs_train, decoder_states_train, decoder_seq_len_train = seq2seq.dynamic_decode(decoder_train)
                decoder_logits_train = decoder_outputs_train.rnn_output
                self.decoder_logits_train = tf.reshape(decoder_logits_train, [batch_size, -1, self.word_vocab_size])

            """For test the model"""
            # if self.mode == "infer" or self.if_pred_on_dev:
            if decoder_type == "greedy":
                decoder_infer_helper = seq2seq.GreedyEmbeddingHelper(self.word_embeddings,
                                                                     tf.ones([batch_size], dtype=tf.int32),
                                                                     self.EOS)
                decoder_infer = seq2seq.BasicDecoder(self.decoder_cell, decoder_infer_helper,
                                                     self.decoder_initial_state, projection_layer)
            elif decoder_type == "beam":
                # BeamSearchDecoder expects start_tokens of the TRUE batch size
                # (it tiles internally); batch_size was multiplied by beam_width
                # above only for the attention zero_state
                decoder_infer = seq2seq.BeamSearchDecoder(cell=self.decoder_cell, embedding=self.word_embeddings,
                                                          start_tokens=tf.ones([self.batch_size], dtype=tf.int32),
                                                          end_token=self.EOS,
                                                          initial_state=self.decoder_initial_state,
                                                          beam_width=beam_width,
                                                          output_layer=projection_layer)

            decoder_outputs_infer, decoder_states_infer, decoder_seq_len_infer = seq2seq.dynamic_decode(decoder_infer,
                                                                                                        maximum_iterations=seq_max_len)

            if decoder_type == "beam":
                self.decoder_logits_infer = tf.no_op()
                self.sample_id = decoder_outputs_infer.predicted_ids

            elif decoder_type == "greedy":
                self.decoder_logits_infer = decoder_outputs_infer.rnn_output
                self.sample_id = decoder_outputs_infer.sample_id
Example #18
    def __init__(self,
                 vocab_size,
                 hidden_size,
                 dropout,
                 num_layers,
                 max_gradient_norm,
                 batch_size,
                 learning_rate,
                 lr_decay_factor,
                 max_target_length,
                 max_source_length,
                 decoder_mode=False):
        '''
        vocab_size: number of vocab tokens
        hidden_size: dimension of hidden layers
        dropout: input keep probability for the dropout wrappers
        num_layers: number of hidden layers
        max_gradient_norm: maximum gradient magnitude
        batch_size: number of training examples fed to network at once
        learning_rate: starting learning rate of network
        lr_decay_factor: amount by which to decay learning rate
        max_target_length / max_source_length: maximum sequence lengths
        decoder_mode: whether to build the beam-search inference graph
                      instead of the training (backprop) graph
        '''
        GO_ID = config.GO_ID
        EOS_ID = config.EOS_ID
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = learning_rate
        self.encoder_inputs = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='encoder_inputs')
        self.source_lengths = tf.placeholder(shape=(None, ),
                                             dtype=tf.int32,
                                             name='source_lengths')

        self.decoder_targets = tf.placeholder(shape=(None, None),
                                              dtype=tf.int32,
                                              name='decoder_targets')
        self.target_lengths = tf.placeholder(shape=(None, ),
                                             dtype=tf.int32,
                                             name="target_lengths")

        with tf.variable_scope('embeddings') as scope:
            embeddings = tf.Variable(tf.random_uniform(
                [vocab_size, hidden_size], -1.0, 1.0),
                                     dtype=tf.float32)
            encoder_inputs_embedded = tf.nn.embedding_lookup(
                embeddings, self.encoder_inputs)
            targets_embedding = tf.nn.embedding_lookup(embeddings,
                                                       self.decoder_targets)

        with tf.variable_scope('encoder') as scope:
            # build a fresh cell per layer; `[cell] * num_layers` would make
            # every layer share the same cell object (and its variables)
            encoder_cells = [
                rnn.DropoutWrapper(rnn.LSTMCell(hidden_size),
                                   input_keep_prob=dropout)
                for _ in range(num_layers)
            ]
            encoder_cell = rnn.MultiRNNCell(encoder_cells)
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell=encoder_cell,
                inputs=encoder_inputs_embedded,
                sequence_length=self.source_lengths,
                dtype=tf.float32,
                time_major=False)

        with tf.variable_scope('decoder') as scope:
            decoder_cells = [
                rnn.DropoutWrapper(rnn.LSTMCell(hidden_size),
                                   input_keep_prob=dropout)
                for _ in range(num_layers)
            ]
            decoder_cell = rnn.MultiRNNCell(decoder_cells, state_is_tuple=True)

        if decoder_mode:
            beam_width = 2
            # BeamSearchDecoder needs the cell, an output projection, and an
            # initial state tiled to batch_size * beam_width
            tiled_encoder_state = seq2seq.tile_batch(encoder_state,
                                                     multiplier=beam_width)
            decoder = seq2seq.BeamSearchDecoder(cell=decoder_cell,
                                                embedding=embeddings,
                                                start_tokens=tf.tile(
                                                    [GO_ID], [batch_size]),
                                                end_token=EOS_ID,
                                                initial_state=tiled_encoder_state,
                                                beam_width=beam_width,
                                                output_layer=Dense(vocab_size))
            final_outputs, _, _ = seq2seq.dynamic_decode(
                decoder, maximum_iterations=self.max_target_length)
            self.logits = final_outputs.predicted_ids
        else:

            helper = seq2seq.TrainingHelper(targets_embedding,
                                            self.target_lengths)
            decoder = seq2seq.BasicDecoder(decoder_cell, helper, encoder_state,
                                           Dense(vocab_size))


            final_outputs, final_state, final_sequence_lengths =\
                            seq2seq.dynamic_decode(decoder=decoder)

            self.logits = final_outputs.rnn_output

        if not decoder_mode:
            with tf.variable_scope("loss") as scope:
                #have to pad logits, dynamic decode produces results not consistent
                #in shape with targets
                pad_size = self.max_target_length - tf.reduce_max(
                    final_sequence_lengths)
                self.logits = tf.pad(self.logits,
                                     [[0, 0], [0, pad_size], [0, 0]])

                weights = tf.sequence_mask(lengths=final_sequence_lengths,
                                           maxlen=self.max_target_length,
                                           dtype=tf.float32,
                                           name='weights')

                x_entropy_loss = seq2seq.sequence_loss(
                    logits=self.logits,
                    targets=self.decoder_targets,
                    weights=weights)  #cross-entropy loss function

                self.loss = tf.reduce_mean(x_entropy_loss)

            optimizer = tf.train.AdamOptimizer()
            gradients = optimizer.compute_gradients(self.loss)
            # clip_by_value fails on None gradients, so skip variables that
            # do not contribute to the loss
            capped_grads = [(tf.clip_by_value(grad, -max_gradient_norm,
                                              max_gradient_norm), var)
                            for grad, var in gradients if grad is not None]
            self.train_op = optimizer.apply_gradients(
                capped_grads, global_step=self.global_step)
            self.saver = tf.train.Saver(tf.global_variables())
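The beam branch in the example above only works once the decoder's initial state is tiled beam_width times along the batch axis; BeamSearchDecoder then treats each (example, beam) pair as its own batch row. A minimal sketch of that contract, assuming `seq2seq` is `tf.contrib.seq2seq` and `Dense` is `tf.layers.Dense`; the helper name is illustrative:

    def make_beam_decoder(cell, encoder_state, embeddings, batch_size,
                          beam_width, vocab_size, go_id, eos_id):
        # tile_batch repeats every tensor in the (possibly nested) state
        # beam_width times, so beam k of batch item i starts from state i.
        tiled_state = seq2seq.tile_batch(encoder_state, multiplier=beam_width)
        return seq2seq.BeamSearchDecoder(
            cell=cell,
            embedding=embeddings,
            start_tokens=tf.fill([batch_size], go_id),
            end_token=eos_id,
            initial_state=tiled_state,
            beam_width=beam_width,
            output_layer=Dense(vocab_size))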
Example #19
    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_unit,
                 latent_dim,
                 emoji_dim,
                 batch_size,
                 kl_ceiling,
                 bow_ceiling,
                 decoder_layer=1,
                 start_i=1,
                 end_i=2,
                 beam_width=0,
                 maximum_iterations=50,
                 max_gradient_norm=5,
                 lr=1e-3,
                 dropout=0.2,
                 num_gpu=2,
                 cell_type=tf.nn.rnn_cell.GRUCell,
                 is_seq2seq=False):
        self.ori_sample = None
        self.rep_sample = None
        self.out_sample = None

        self.sess = None

        self.loss_weight = tf.placeholder_with_default(0., shape=())
        self.policy_weight = tf.placeholder_with_default(1., shape=())
        self.ac_vec = tf.placeholder(tf.float32,
                                     shape=[batch_size],
                                     name="accuracy_vector")
        self.ac5_vec = tf.placeholder(tf.float32,
                                      shape=[batch_size],
                                      name="top5_accuracy_vector")

        self.is_policy = tf.placeholder_with_default(False, shape=())
        shape = [batch_size, latent_dim]
        self.rdm = tf.placeholder_with_default(np.zeros(shape,
                                                        dtype=np.float32),
                                               shape=shape)
        self.q_rdm = tf.placeholder_with_default(np.zeros(shape,
                                                          dtype=np.float32),
                                                 shape=shape)

        self.end_i = end_i
        self.batch_size = batch_size
        self.num_gpu = num_gpu
        self.num_unit = num_unit
        self.dropout = tf.placeholder_with_default(dropout, (), name="dropout")
        self.beam_width = beam_width
        self.cell_type = cell_type

        self.emoji = tf.placeholder(tf.int32, shape=[batch_size], name="emoji")
        self.ori = tf.placeholder(tf.int32,
                                  shape=[None, batch_size],
                                  name="original_tweet")  # [len, batch_size]
        self.ori_len = tf.placeholder(tf.int32,
                                      shape=[batch_size],
                                      name="original_tweet_length")
        self.rep = tf.placeholder(tf.int32,
                                  shape=[None, batch_size],
                                  name="response_tweet")
        self.rep_len = tf.placeholder(tf.int32,
                                      shape=[batch_size],
                                      name="response_tweet_length")
        self.rep_input = tf.placeholder(tf.int32,
                                        shape=[None, batch_size],
                                        name="response_start_tag")
        self.rep_output = tf.placeholder(tf.int32,
                                         shape=[None, batch_size],
                                         name="response_end_tag")

        self.reward = tf.placeholder(tf.float32,
                                     shape=[batch_size],
                                     name="reward")

        self.kl_weight = tf.placeholder_with_default(1.,
                                                     shape=(),
                                                     name="kl_weight")

        self.placeholders = [
            self.emoji, self.ori, self.ori_len, self.rep, self.rep_len,
            self.rep_input, self.rep_output
        ]

        with tf.variable_scope("embeddings"):
            embedding = Embedding(vocab_size, embed_size)

            ori_emb = embedding(
                self.ori)  # [max_len, batch_size, embedding_size]
            rep_emb = embedding(self.rep)
            rep_input_emb = embedding(self.rep_input)
            emoji_emb = embedding(self.emoji)  # [batch_size, embedding_size]

        with tf.variable_scope("original_tweet_encoder"):
            ori_encoder_output, ori_encoder_state = build_bidirectional_rnn(
                num_unit,
                ori_emb,
                self.ori_len,
                cell_type,
                num_gpu,
                self.dropout,
                base_gpu=0)
            ori_encoder_state_flat = tf.concat(
                [ori_encoder_state[0], ori_encoder_state[1]], axis=1)

        emoji_vec = tf.layers.dense(emoji_emb,
                                    emoji_dim,
                                    activation=tf.nn.tanh)
        self.emoji_vec = emoji_emb
        # emoji_vec = tf.ones([batch_size, emoji_dim], tf.float32)
        condition_flat = tf.concat([ori_encoder_state_flat, emoji_vec], axis=1)

        with tf.variable_scope("response_tweet_encoder"):
            _, rep_encoder_state = build_bidirectional_rnn(num_unit,
                                                           rep_emb,
                                                           self.rep_len,
                                                           cell_type,
                                                           num_gpu,
                                                           self.dropout,
                                                           base_gpu=2)
            rep_encoder_state_flat = tf.concat(
                [rep_encoder_state[0], rep_encoder_state[1]], axis=1)

        with tf.variable_scope("representation_network"):
            rn_input = tf.concat([rep_encoder_state_flat, condition_flat],
                                 axis=1)
            # simpler representation network
            # r_hidden = rn_input
            r_hidden = tf.layers.dense(
                rn_input,
                latent_dim,
                activation=tf.nn.relu,
                name="r_net_hidden")  # int(1.6 * latent_dim)
            r_hidden_mu = tf.layers.dense(
                r_hidden, latent_dim,
                activation=tf.nn.relu)  # int(1.3 * latent_dim)
            r_hidden_var = tf.layers.dense(r_hidden,
                                           latent_dim,
                                           activation=tf.nn.relu)
            self.mu = tf.layers.dense(r_hidden_mu,
                                      latent_dim,
                                      activation=tf.nn.tanh,
                                      name="q_mean")
            self.log_var = tf.layers.dense(r_hidden_var,
                                           latent_dim,
                                           activation=tf.nn.tanh,
                                           name="q_log_var")

        with tf.variable_scope("prior_network"):
            # simpler prior network
            # p_hidden = condition_flat
            p_hidden = tf.layers.dense(condition_flat,
                                       int(0.62 * latent_dim),
                                       activation=tf.nn.relu,
                                       name="r_net_hidden")
            p_hidden_mu = tf.layers.dense(p_hidden,
                                          int(0.77 * latent_dim),
                                          activation=tf.nn.relu)
            p_hidden_var = tf.layers.dense(p_hidden,
                                           int(0.77 * latent_dim),
                                           activation=tf.nn.relu)
            self.p_mu = tf.layers.dense(p_hidden_mu,
                                        latent_dim,
                                        activation=tf.nn.tanh,
                                        name="p_mean")
            self.p_log_var = tf.layers.dense(p_hidden_var,
                                             latent_dim,
                                             activation=tf.nn.tanh,
                                             name="p_log_var")

        with tf.variable_scope("reparameterization"):
            self.normal = tf.cond(
                self.is_policy, lambda: self.rdm,
                lambda: tf.random_normal(shape=tf.shape(self.mu)))
            self.z_sample = self.mu + tf.exp(self.log_var / 2.) * self.normal

            self.q_normal = tf.cond(
                self.is_policy, lambda: self.q_rdm,
                lambda: tf.random_normal(shape=tf.shape(self.p_mu)))
            self.q_z_sample = self.p_mu + tf.exp(
                self.p_log_var / 2.) * self.q_normal
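            # Reparameterization trick: z = mu + exp(log_var / 2) * eps with
            # eps ~ N(0, I), so gradients flow through mu and log_var while
            # the noise source stays parameter-free.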

        if is_seq2seq:
            # Plain seq2seq ablation: cancel the latent samples to zero while
            # keeping tensor shapes and the graph structure intact.
            self.z_sample = self.z_sample - self.z_sample
            self.q_z_sample = self.q_z_sample - self.q_z_sample

        with tf.variable_scope("decoder_train") as decoder_scope:
            if decoder_layer == 2:
                train_decoder_init_state = (
                    tf.concat([self.z_sample, ori_encoder_state[0], emoji_vec],
                              axis=1),
                    tf.concat([self.z_sample, ori_encoder_state[1], emoji_vec],
                              axis=1))
                dim = latent_dim + num_unit + emoji_dim
                cell = tf.nn.rnn_cell.MultiRNNCell([
                    create_rnn_cell(dim, 2, cell_type, num_gpu, self.dropout),
                    create_rnn_cell(dim, 3, cell_type, num_gpu, self.dropout)
                ])
            else:
                train_decoder_init_state = tf.concat(
                    [self.z_sample, ori_encoder_state_flat, emoji_vec], axis=1)
                dim = latent_dim + 2 * num_unit + emoji_dim
                cell = create_rnn_cell(dim, 2, cell_type, num_gpu,
                                       self.dropout)

            with tf.variable_scope("attention"):
                memory = tf.concat(
                    [ori_encoder_output[0], ori_encoder_output[1]], axis=2)
                memory = tf.transpose(memory, [1, 0, 2])

                attention_mechanism = seq2seq.LuongAttention(
                    dim,
                    memory,
                    memory_sequence_length=self.ori_len,
                    scale=True)
                # attention_mechanism = seq2seq.BahdanauAttention(
                #     num_unit, memory, memory_sequence_length=self.ori_len)

            decoder_cell = seq2seq.AttentionWrapper(
                cell, attention_mechanism, attention_layer_size=dim
            )  # TODO: name this op; document what attention_layer_size controls
            # decoder_cell = cell

            helper = seq2seq.TrainingHelper(rep_input_emb,
                                            self.rep_len + 1,
                                            time_major=True)
            projection_layer = layers_core.Dense(vocab_size,
                                                 use_bias=False,
                                                 name="output_projection")
            decoder = seq2seq.BasicDecoder(
                decoder_cell,
                helper,
                decoder_cell.zero_state(
                    batch_size,
                    tf.float32).clone(cell_state=train_decoder_init_state),
                output_layer=projection_layer)
            train_outputs, _, _ = seq2seq.dynamic_decode(
                decoder,
                output_time_major=True,
                swap_memory=True,
                scope=decoder_scope)
            self.logits = train_outputs.rnn_output
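            # With output_time_major=True the logits come back as
            # [max_len, batch_size, vocab_size], matching the time-major
            # rep_output targets used in the loss below.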

        with tf.variable_scope("decoder_infer") as decoder_scope:
            # normal_sample = tf.random_normal(shape=(batch_size, latent_dim))

            if decoder_layer == 2:
                infer_decoder_init_state = (
                    tf.concat(
                        [self.q_z_sample, ori_encoder_state[0], emoji_vec],
                        axis=1),
                    tf.concat(
                        [self.q_z_sample, ori_encoder_state[1], emoji_vec],
                        axis=1))
            else:
                infer_decoder_init_state = tf.concat(
                    [self.q_z_sample, ori_encoder_state_flat, emoji_vec],
                    axis=1)

            start_tokens = tf.fill([batch_size], start_i)
            end_token = end_i

            if beam_width > 0:
                infer_decoder_init_state = seq2seq.tile_batch(
                    infer_decoder_init_state, multiplier=beam_width)
                decoder = seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding.coder,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=decoder_cell.zero_state(
                        batch_size * beam_width,
                        tf.float32).clone(cell_state=infer_decoder_init_state),
                    beam_width=beam_width,
                    output_layer=projection_layer,
                    length_penalty_weight=0.0)
            else:
                helper = seq2seq.GreedyEmbeddingHelper(embedding.coder,
                                                       start_tokens, end_token)
                decoder = seq2seq.BasicDecoder(
                    decoder_cell,
                    helper,
                    decoder_cell.zero_state(
                        batch_size,
                        tf.float32).clone(cell_state=infer_decoder_init_state),
                    output_layer=projection_layer  # applied per timestep
                )

            # Dynamic decoding
            infer_outputs, _, infer_lengths = seq2seq.dynamic_decode(
                decoder,
                maximum_iterations=maximum_iterations,
                output_time_major=True,
                swap_memory=True,
                scope=decoder_scope)
            if beam_width > 0:
                self.result = infer_outputs.predicted_ids
            else:
                self.result = infer_outputs.sample_id
                self.result_lengths = infer_lengths

        with tf.variable_scope("loss"):
            max_time = tf.shape(self.rep_output)[0]
            with tf.variable_scope("reconstruction"):
                # TODO: use inference decoder's logits to compute recon_loss
                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(  # ce = [len, batch_size]
                    labels=self.rep_output,
                    logits=self.logits)
                # rep: [len, batch_size]; logits: [len, batch_size, vocab_size]
                target_mask = tf.sequence_mask(self.rep_len + 1,
                                               max_time,
                                               dtype=self.logits.dtype)
                # time_major
                target_mask_t = tf.transpose(target_mask)  # max_len batch_size
                self.recon_losses = tf.reduce_sum(cross_entropy *
                                                  target_mask_t,
                                                  axis=0)
                self.recon_loss = tf.reduce_sum(
                    cross_entropy * target_mask_t) / batch_size

            with tf.variable_scope("latent"):
                # without prior network
                # self.kl_loss = 0.5 * tf.reduce_sum(tf.exp(self.log_var) + self.mu ** 2 - 1. - self.log_var, 0)
                self.kl_losses = 0.5 * tf.reduce_sum(
                    tf.exp(self.log_var - self.p_log_var) +
                    (self.mu - self.p_mu)**2 / tf.exp(self.p_log_var) - 1. -
                    self.log_var + self.p_log_var,
                    axis=1)
                self.kl_loss = tf.reduce_mean(self.kl_losses)
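                # Closed-form KL(q || p) between the diagonal Gaussians
                # q = N(mu, exp(log_var)) and p = N(p_mu, exp(p_log_var)):
                # 0.5 * sum(exp(lq - lp) + (mu - p_mu)^2 / exp(lp)
                #           - 1 - lq + lp), summed over latent dims.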

            with tf.variable_scope("bow"):
                # self.bow_loss = self.kl_weight * 0
                mlp_b = layers_core.Dense(vocab_size,
                                          use_bias=False,
                                          name="MLP_b")
                # Note: the bag-of-words logits condition on the latent
                # sample together with the encoder state and emoji vector,
                # not on the latent variable alone.
                latent_logits = mlp_b(
                    tf.concat(
                        [self.z_sample, ori_encoder_state_flat, emoji_vec],
                        axis=1))  # [batch_size, vocab_size]
                latent_logits = tf.expand_dims(
                    latent_logits, 0)  # [1, batch_size, vocab_size]
                latent_logits = tf.tile(
                    latent_logits,
                    [max_time, 1, 1])  # [max_time, batch_size, vocab_size]

                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(  # ce = [len, batch_size]
                    labels=self.rep_output,
                    logits=latent_logits)
                self.bow_losses = tf.reduce_sum(cross_entropy * target_mask_t,
                                                axis=0)
                self.bow_loss = tf.reduce_sum(
                    cross_entropy * target_mask_t) / batch_size

            if is_seq2seq:
                # Zero the KL and BoW terms for the plain seq2seq ablation,
                # leaving only the reconstruction loss.
                self.kl_losses = self.kl_losses - self.kl_losses
                self.bow_losses = self.bow_losses - self.bow_losses
                self.kl_loss = self.kl_loss - self.kl_loss
                self.bow_loss = self.bow_loss - self.bow_loss

            self.losses = (self.recon_losses +
                           self.kl_losses * self.kl_weight * kl_ceiling +
                           self.bow_losses * bow_ceiling)
            self.loss = tf.reduce_mean(self.losses)

        # Calculate and clip gradients
        with tf.variable_scope("optimization"):
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, max_gradient_norm)

            # Optimization
            optimizer = tf.train.AdamOptimizer(lr)
            self.update_step = optimizer.apply_gradients(
                zip(clipped_gradients, params))

        with tf.variable_scope("policy_loss"):
            prob = tf.nn.softmax(
                infer_outputs.rnn_output)  # [max_len, batch_size, vocab_size]
            prob = tf.clip_by_value(prob, 1e-15, 1000.)
            output_prob = tf.reduce_max(tf.log(prob),
                                        axis=2)  # [max_len, batch_size]
            seq_log_prob = tf.reduce_sum(output_prob, axis=0)  # batch_size
            # reward = tf.nn.relu(self.reward)
            self.policy_losses = -self.reward * seq_log_prob
            self.policy_losses *= (0.5 - 1) * self.ac5_vec + 1
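            # REINFORCE-style objective: weighting the (approximate) sequence
            # log-probability by -reward raises the likelihood of
            # well-rewarded samples; (0.5 - 1) * ac5_vec + 1 halves the loss
            # for samples that were already correct in the top 5.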

        with tf.variable_scope("policy_optimization"):
            # zero = tf.constant(0, dtype=tf.float32)
            # where = tf.cast(tf.less(self.reward, zero), tf.float32)
            # recon = tf.reduce_sum(self.recon_losses * where) / tf.reduce_sum(where)

            final_loss = self.policy_losses * (
                1 - self.ac_vec) * self.policy_weight
            final_loss += self.losses * self.loss_weight
            self.policy_loss = tf.reduce_mean(final_loss)

            # final_loss = self.losses * self.loss_weight + self.policy_losses * self.policy_weight
            # final_loss *= (1 - self.ac_vec)
            # self.policy_loss = tf.reduce_sum(final_loss) / tf.reduce_sum((1 - self.ac_vec))

            gradients = tf.gradients(self.policy_loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(lr)
            self.policy_step = optimizer.apply_gradients(
                zip(clipped_gradients, params))
Example #20
def _build_decoder(model, encoder_outputs, encoder_state, hparams, start_token,
                   end_token, output_layer, aux_hidden_state):
  """build decoder for the seq2seq model."""

  iterator = model.iterator

  start_token_id = tf.cast(
      model.vocab_table.lookup(tf.constant(start_token)), tf.int32)
  end_token_id = tf.cast(
      model.vocab_table.lookup(tf.constant(end_token)), tf.int32)

  start_tokens = tf.fill([model.batch_size], start_token_id)
  end_token = end_token_id

  ## Decoder.
  with tf.variable_scope("decoder") as decoder_scope:
    cell, decoder_initial_state = _build_decoder_cell(
        model, hparams, encoder_state, base_gpu=model.global_gpu_num)
    model.global_gpu_num += hparams.num_layers
    ## Train or eval

    decoder_emb_inp = tf.nn.embedding_lookup(model.embedding_decoder,
                                             iterator.target)
    # Helper
    helper_train = help_py.TrainingHelper(
        decoder_emb_inp, iterator.dialogue_len, time_major=False)

    # Decoder
    my_decoder_train = basic_decoder.BasicDecoder(
        cell,
        helper_train,
        decoder_initial_state,
        encoder_outputs,
        iterator.turns,
        output_layer=output_layer,
        aux_hidden_state=aux_hidden_state)

    # Dynamic decoding
    outputs_train, _, _ = seq2seq.dynamic_decode(
        my_decoder_train,
        output_time_major=False,
        swap_memory=True,
        scope=decoder_scope)

    sample_id_train = outputs_train.sample_id
    logits_train = outputs_train.rnn_output
    ## Inference

    beam_width = hparams.beam_width
    length_penalty_weight = hparams.length_penalty_weight

    if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
      my_decoder_infer = seq2seq.BeamSearchDecoder(
          cell=cell,
          embedding=model.embedding_decoder,
          start_tokens=start_tokens,
          end_token=end_token,
          initial_state=decoder_initial_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=length_penalty_weight)
    else:
      # Helper
      if model.mode in dialogue_utils.self_play_modes:
        helper_infer = seq2seq.SampleEmbeddingHelper(
            model.embedding_decoder, start_tokens, end_token)
      else:  # inference
        helper_infer = seq2seq.GreedyEmbeddingHelper(
            model.embedding_decoder, start_tokens, end_token)

      # Decoder
      my_decoder_infer = seq2seq.BasicDecoder(
          cell,
          helper_infer,
          decoder_initial_state,
          output_layer=output_layer  # applied per timestep
      )

    # Dynamic decoding
    outputs_infer, _, _ = seq2seq.dynamic_decode(
        my_decoder_infer,
        maximum_iterations=hparams.max_inference_len,
        output_time_major=False,
        swap_memory=True,
        scope=decoder_scope)

    if model.mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
      logits_infer = tf.no_op()
      sample_id_infer = outputs_infer.predicted_ids
    else:
      logits_infer = outputs_infer.rnn_output
      sample_id_infer = outputs_infer.sample_id

  return logits_train, logits_infer, sample_id_train, sample_id_infer
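The `_build_decoder` above picks its inference helper by mode: self-play samples from the softmax, plain inference decodes greedily. A minimal sketch of that switch, assuming `seq2seq` is `tf.contrib.seq2seq`; the function name and `sample` flag are illustrative:

    def make_infer_helper(embedding, start_tokens, end_token, sample=False,
                          temperature=1.0):
        # GreedyEmbeddingHelper feeds back the argmax token each step;
        # SampleEmbeddingHelper draws from softmax(logits / temperature).
        if sample:
            return seq2seq.SampleEmbeddingHelper(
                embedding, start_tokens, end_token,
                softmax_temperature=temperature)
        return seq2seq.GreedyEmbeddingHelper(embedding, start_tokens,
                                             end_token)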
Example #21
    def build_decode(self):
        # build decoder and attention.
        with tf.variable_scope('decoder'):
            self.decoder_cell, self.decoder_initial_state = self.build_decode_cell(
            )

            # Input projection layer to feed embedded inputs to the cell
            # ** Essential when use_residual=True to match input/output dims
            input_layer = Dense(self.hidden_units,
                                dtype=self.dtype,
                                name='input_projection')

            # Output projection layer to convert cell_outputs to logits
            output_layer = Dense(self.num_decoder_symbols,
                                 name='output_project')

            if self.mode == 'train':
                # decoder_inputs_embedded: [batch_size, max_time_step + 1, embedding_size]
                self.decoder_inputs_embedded = tf.nn.embedding_lookup(
                    self.embeddings, self.decoder_inputs_train)

                # Embedded inputs having gone through input projection layer
                self.decoder_inputs_embedded = input_layer(
                    self.decoder_inputs_embedded)

                # Helper to feed inputs for training: read inputs from dense ground truth vectors
                training_helper = seq2seq.TrainingHelper(
                    inputs=self.decoder_inputs_embedded,
                    sequence_length=self.decoder_inputs_length_train,
                    time_major=False,
                    name='training_helper')

                training_decoder = seq2seq.BasicDecoder(
                    cell=self.decoder_cell,
                    helper=training_helper,
                    initial_state=self.decoder_initial_state,
                    output_layer=output_layer)

                # Maximum decoder time steps in the current batch
                max_decoder_length = tf.reduce_max(
                    self.decoder_inputs_length_train)

                # decoder_outputs_train: BasicDecoderOutput
                #                        namedtuple(rnn_outputs, sample_id)
                # decoder_outputs_train.rnn_output: [batch_size, max_time_step + 1, num_decoder_symbols] if output_time_major=False
                #                                   [max_time_step + 1, batch_size, num_decoder_symbols] if output_time_major=True
                # decoder_outputs_train.sample_id: [batch_size], tf.int32
                (self.decoder_outputs_train, self.decoder_last_state_train,
                 self.decoder_outputs_length_train) = (seq2seq.dynamic_decode(
                     decoder=training_decoder,
                     output_time_major=False,
                     impute_finished=True,
                     maximum_iterations=max_decoder_length))
                # More efficient to do the projection on the batch-time-concatenated tensor
                # logits_train: [batch_size, max_time_step + 1, num_decoder_symbols]
                # self.decoder_logits_train = output_layer(self.decoder_outputs_train.rnn_output)
                self.decoder_logits_train = tf.identity(
                    self.decoder_outputs_train.rnn_output)

                # Use argmax to extract decoder symbols to emit
                self.decoder_pred_train = tf.argmax(self.decoder_logits_train,
                                                    axis=-1,
                                                    name='decoder_pre_train')

                # masks: masking for valid and padded time steps, [batch_size, max_time_step + 1]
                masks = tf.sequence_mask(
                    lengths=self.decoder_inputs_length_train,
                    maxlen=max_decoder_length,
                    dtype=self.dtype,
                    name='masks')

                self.loss = seq2seq.sequence_loss(
                    logits=self.decoder_logits_train,
                    targets=self.decoder_targets_train,
                    weights=masks,
                    average_across_timesteps=True,
                    average_across_batch=True)

                # Training summary for the current batch_loss
                tf.summary.scalar('loss', self.loss)
            elif self.mode == 'decode':
                # Start_tokens: [batch_size,] `int32` vector
                start_token = tf.ones([
                    self.batch_size,
                ], tf.int32) * data_utils.GO_ID
                end_token = data_utils.EOS_ID

                def embed_and_input_proj(inputs):
                    return input_layer(
                        tf.nn.embedding_lookup(self.embeddings, inputs))

                if not self.use_beamsearch_decode:
                    decoding_helper = seq2seq.GreedyEmbeddingHelper(
                        start_tokens=start_token,
                        end_token=end_token,
                        embedding=embed_and_input_proj)
                    inference_decoder = seq2seq.BasicDecoder(
                        cell=self.decoder_cell,
                        helper=decoding_helper,
                        initial_state=self.decoder_initial_state,
                        output_layer=output_layer)
                else:
                    inference_decoder = seq2seq.BeamSearchDecoder(
                        cell=self.decoder_cell,
                        embedding=embed_and_input_proj,
                        start_tokens=start_token,
                        end_token=end_token,
                        initial_state=self.decoder_initial_state,
                        beam_width=self.beam_width,
                        output_layer=output_layer)

                    # For GreedyDecoder, return
                    # decoder_outputs_decode: BasicDecoderOutput instance
                    #                         namedtuple(rnn_outputs, sample_id)
                    # decoder_outputs_decode.rnn_output: [batch_size, max_time_step, num_decoder_symbols] if output_time_major=False
                    #                                    [max_time_step, batch_size, num_decoder_symbols] if output_time_major=True
                    # decoder_outputs_decode.sample_id: [batch_size, max_time_step], tf.int32 if output_time_major=False
                    #                                   [max_time_step, batch_size], tf.int32 if output_time_major=True

                    # For BeamSearchDecoder, return
                    # decoder_outputs_decode: FinalBeamSearchDecoderOutput instance
                    #                         namedtuple(predicted_ids, beam_search_decoder_output)
                    # decoder_outputs_decode.predicted_ids: [batch_size, max_time_step, beam_width] if output_time_major=False
                    #                                       [max_time_step, batch_size, beam_width] if output_time_major=True
                    # decoder_outputs_decode.beam_search_decoder_output: BeamSearchDecoderOutput instance
                    #                                                    namedtuple(scores, predicted_ids, parent_ids)

                (self.decoder_outputs_decode, self.decoder_last_state_decode,
                 self.decoder_outputs_length_decode) = (seq2seq.dynamic_decode(
                     decoder=inference_decoder,
                     output_time_major=False,
                     maximum_iterations=self.config.max_decode_step))

                if not self.use_beamsearch_decode:
                    # decoder_outputs_decode.sample_id: [batch_size, max_time_step]
                    # Or use argmax to find decoder symbols to emit:
                    # self.decoder_pred_decode = tf.argmax(self.decoder_outputs_decode.rnn_output,
                    #                                      axis=-1, name='decoder_pred_decode')

                    # Here, we use expand_dims to be compatible with the result of the beamsearch decoder
                    # decoder_pred_decode: [batch_size, max_time_step, 1] (output_major=False)
                    self.decoder_pred_decode = tf.expand_dims(
                        self.decoder_outputs_decode.sample_id, -1)
                else:
                    # Use beam search to approximately find the most likely translation
                    # decoder_pred_decode: [batch_size, max_time_step, beam_width] (output_major=False)
                    self.decoder_pred_decode = self.decoder_outputs_decode.predicted_ids
Example #22
    def _build_decoder(self, encoder_outputs, encoder_state, hparams):
        """Build and run a RNN decoder with a final projection layer.
    
        Args:
          encoder_outputs: The outputs of encoder for every time step.
          encoder_state: The final state of the encoder.
          hparams: The Hyperparameters configurations.
    
        Returns:
          A tuple of final logits and final decoder state:
            logits: size [time, batch_size, vocab_size] when time_major=True.
        """
        tgt_sos_id = tf.cast(
            self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        tgt_eos_id = tf.cast(
            self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        iterator = self.iterator

        # maximum_iteration: The maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(
            hparams, iterator.source_sequence_length)

        ## Decoder.
        with tf.variable_scope("decoder") as decoder_scope:
            cell, decoder_initial_state = self._build_decoder_cell(
                hparams, encoder_outputs, encoder_state,
                iterator.source_sequence_length)

            ## Train or eval
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                # decoder_emb_inp: [max_time, batch_size, num_units]
                target_input = iterator.target_input
                if self.time_major:
                    target_input = tf.transpose(target_input)
                decoder_emb_inp = tf.nn.embedding_lookup(
                    self.embedding_decoder, target_input)

                # Helper
                helper = seq2seq.TrainingHelper(
                    decoder_emb_inp,
                    iterator.target_sequence_length,
                    time_major=self.time_major)

                # Decoder
                my_decoder = seq2seq.BasicDecoder(
                    cell,
                    helper,
                    decoder_initial_state,
                )

                # Dynamic decoding
                outputs, final_context_state, \
                _ = seq2seq.dynamic_decode(
                        my_decoder,
                        output_time_major=self.time_major,
                        swap_memory=True,
                        scope=decoder_scope)

                sample_id = outputs.sample_id

                # Note: there's a subtle difference here between train and
                # inference.
                # We could have set output_layer when create my_decoder
                #   and shared more code between train and inference.
                # We chose to apply the output_layer to all timesteps for speed:
                #   10% improvements for small models & 20% for larger ones.
                # If memory is a concern, we should apply output_layer per
                # timestep.
                logits = self.output_layer(outputs.rnn_output)

            ## Inference
            else:
                beam_width = hparams.beam_width
                length_penalty_weight = hparams.length_penalty_weight
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                if beam_width > 0:
                    my_decoder = seq2seq.BeamSearchDecoder(
                        cell=cell,
                        embedding=self.embedding_decoder,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=beam_width,
                        output_layer=self.output_layer,
                        length_penalty_weight=length_penalty_weight)
                else:
                    # Helper
                    sampling_temperature = hparams.sampling_temperature
                    if sampling_temperature > 0.0:
                        helper = seq2seq.SampleEmbeddingHelper(
                            self.embedding_decoder,
                            start_tokens,
                            end_token,
                            softmax_temperature=sampling_temperature,
                            seed=hparams.random_seed)
                    else:
                        helper = seq2seq.GreedyEmbeddingHelper(
                            self.embedding_decoder, start_tokens, end_token)

                    # Decoder
                    my_decoder = seq2seq.BasicDecoder(
                        cell,
                        helper,
                        decoder_initial_state,
                        output_layer=self.output_layer
                        # applied per timestep
                    )

                # Dynamic decoding
                outputs, final_context_state, \
                _ = seq2seq.dynamic_decode(
                        my_decoder,
                        maximum_iterations=maximum_iterations,
                        output_time_major=self.time_major,
                        swap_memory=True,
                        scope=decoder_scope)

                if beam_width > 0:
                    # TODO rerank here
                    logits = tf.no_op()
                    sample_id = outputs.predicted_ids
                else:
                    logits = outputs.rnn_output
                    sample_id = outputs.sample_id

        return logits, sample_id, final_context_state
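`_get_infer_maximum_iterations` is not shown in the snippet above; a common heuristic for this cap (an assumption here, not confirmed by the snippet) is a small multiple of the longest source sequence:

    def infer_maximum_iterations(source_sequence_length,
                                 decoding_length_factor=2.0):
        # Cap decoding at a multiple of the longest source sentence so
        # greedy/sampled decoding terminates even if EOS is never emitted.
        max_encoder_length = tf.reduce_max(source_sequence_length)
        return tf.to_int32(
            tf.round(tf.to_float(max_encoder_length) * decoding_length_factor))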
Example #23
    def _build_forward(self):
        config = self.config
        N, M, JX, JQ, VW, VC, d, W = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
            config.max_word_size
        beam_width = config.beam_width
        GO_TOKEN = 0
        EOS_TOKEN = 1

        JX = tf.shape(self.x)[2]
        JQ = tf.shape(self.q)[1]
        M = tf.shape(self.x)[1]
        dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size

        with tf.variable_scope("emb"):
            if config.use_char_emb:
                with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                    char_emb_mat = tf.get_variable("char_emb_mat",
                                                   shape=[VC, dc],
                                                   dtype='float')

                with tf.variable_scope("char"):
                    Acx = tf.nn.embedding_lookup(char_emb_mat,
                                                 self.cx)  # [N, M, JX, W, dc]
                    Acq = tf.nn.embedding_lookup(char_emb_mat,
                                                 self.cq)  # [N, JQ, W, dc]
                    Acx = tf.reshape(Acx, [-1, JX, W, dc])
                    Acq = tf.reshape(Acq, [-1, JQ, W, dc])

                    filter_sizes = list(
                        map(int, config.out_channel_dims.split(',')))
                    heights = list(map(int, config.filter_heights.split(',')))
                    assert sum(filter_sizes) == dco, (filter_sizes, dco)
                    with tf.variable_scope("conv"):
                        xx = multi_conv1d(Acx,
                                          filter_sizes,
                                          heights,
                                          "VALID",
                                          self.is_train,
                                          config.keep_prob,
                                          scope="xx")
                        if config.share_cnn_weights:
                            tf.get_variable_scope().reuse_variables()
                            qq = multi_conv1d(Acq,
                                              filter_sizes,
                                              heights,
                                              "VALID",
                                              self.is_train,
                                              config.keep_prob,
                                              scope="xx")
                        else:
                            qq = multi_conv1d(Acq,
                                              filter_sizes,
                                              heights,
                                              "VALID",
                                              self.is_train,
                                              config.keep_prob,
                                              scope="qq")
                        xx = tf.reshape(xx, [-1, M, JX, dco])
                        qq = tf.reshape(qq, [-1, JQ, dco])
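                        # Character-level CNN features restored to sentence
                        # layout: xx is [N, M, JX, dco], qq is [N, JQ, dco].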

            if config.use_word_emb:
                with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                    if config.mode == 'train':
                        word_emb_mat = tf.get_variable(
                            "word_emb_mat",
                            dtype='float',
                            shape=[VW, dw],
                            initializer=get_initializer(config.emb_mat),
                            trainable=True)
                    else:
                        word_emb_mat = tf.get_variable("word_emb_mat",
                                                       shape=[VW, dw],
                                                       dtype='float')
                    if config.use_glove_for_unk:
                        word_emb_mat = tf.concat(
                            axis=0, values=[word_emb_mat, self.new_emb_mat])
                with tf.name_scope("word"):
                    Ax = tf.nn.embedding_lookup(word_emb_mat,
                                                self.x)  # [N, M, JX, d]
                    Aq = tf.nn.embedding_lookup(word_emb_mat,
                                                self.q)  # [N, JQ, d]
                    self.tensor_dict['x'] = Ax
                    self.tensor_dict['q'] = Aq
                if config.use_char_emb:
                    xx = tf.concat(axis=3, values=[xx, Ax])  # [N, M, JX, di]
                    qq = tf.concat(axis=2, values=[qq, Aq])  # [N, JQ, di]
                else:
                    xx = Ax
                    qq = Aq

        # highway network
        if config.highway:
            with tf.variable_scope("highway"):
                xx = highway_network(xx,
                                     config.highway_num_layers,
                                     True,
                                     wd=config.wd,
                                     is_train=self.is_train)
                tf.get_variable_scope().reuse_variables()
                qq = highway_network(qq,
                                     config.highway_num_layers,
                                     True,
                                     wd=config.wd,
                                     is_train=self.is_train)

        self.tensor_dict['xx'] = xx
        self.tensor_dict['qq'] = qq

        cell_fw = BasicLSTMCell(d, state_is_tuple=True)
        cell_bw = BasicLSTMCell(d, state_is_tuple=True)
        d_cell_fw = SwitchableDropoutWrapper(
            cell_fw, self.is_train, input_keep_prob=config.input_keep_prob)
        d_cell_bw = SwitchableDropoutWrapper(
            cell_bw, self.is_train, input_keep_prob=config.input_keep_prob)
        cell2_fw = BasicLSTMCell(d, state_is_tuple=True)
        cell2_bw = BasicLSTMCell(d, state_is_tuple=True)
        d_cell2_fw = SwitchableDropoutWrapper(
            cell2_fw, self.is_train, input_keep_prob=config.input_keep_prob)
        d_cell2_bw = SwitchableDropoutWrapper(
            cell2_bw, self.is_train, input_keep_prob=config.input_keep_prob)
        cell3_fw = BasicLSTMCell(d, state_is_tuple=True)
        cell3_bw = BasicLSTMCell(d, state_is_tuple=True)
        d_cell3_fw = SwitchableDropoutWrapper(
            cell3_fw, self.is_train, input_keep_prob=config.input_keep_prob)
        d_cell3_bw = SwitchableDropoutWrapper(
            cell3_bw, self.is_train, input_keep_prob=config.input_keep_prob)
        cell4_fw = BasicLSTMCell(d, state_is_tuple=True)
        cell4_bw = BasicLSTMCell(d, state_is_tuple=True)
        d_cell4_fw = SwitchableDropoutWrapper(
            cell4_fw, self.is_train, input_keep_prob=config.input_keep_prob)
        d_cell4_bw = SwitchableDropoutWrapper(
            cell4_bw, self.is_train, input_keep_prob=config.input_keep_prob)
        x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2)  # [N, M]
        q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1)  # [N]

        with tf.variable_scope("prepro"):
            (fw_u, bw_u), ((_, fw_u_f), (_,
                                         bw_u_f)) = bidirectional_dynamic_rnn(
                                             d_cell_fw,
                                             d_cell_bw,
                                             qq,
                                             q_len,
                                             dtype='float',
                                             scope='u1')  # [N, J, d], [N, d]
            u = tf.concat(axis=2, values=[fw_u, bw_u])
            if config.share_lstm_weights:
                tf.get_variable_scope().reuse_variables()
                (fw_h, bw_h), ((_, fw_h_f),
                               (_, bw_h_f)) = bidirectional_dynamic_rnn(
                                   cell_fw,
                                   cell_bw,
                                   xx,
                                   x_len,
                                   dtype='float',
                                   scope='u1')  # [N, M, JX, 2d]
                h = tf.concat(axis=3, values=[fw_h, bw_h])  # [N, M, JX, 2d]
            else:
                (fw_h, bw_h), ((_, fw_h_f),
                               (_, bw_h_f)) = bidirectional_dynamic_rnn(
                                   cell_fw,
                                   cell_bw,
                                   xx,
                                   x_len,
                                   dtype='float',
                                   scope='h1')  # [N, M, JX, 2d]
                h = tf.concat(axis=3, values=[fw_h, bw_h])  # [N, M, JX, 2d]
            self.tensor_dict['u'] = u
            self.tensor_dict['h'] = h

        with tf.variable_scope("main"):
            if config.dynamic_att:
                p0 = h
                u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]),
                               [N * M, JQ, 2 * d])
                q_mask = tf.reshape(
                    tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]),
                    [N * M, JQ])
                first_cell_fw = AttentionCell(
                    cell2_fw,
                    u,
                    mask=q_mask,
                    mapper='sim',
                    input_keep_prob=self.config.input_keep_prob,
                    is_train=self.is_train)
                first_cell_bw = AttentionCell(
                    cell2_bw,
                    u,
                    mask=q_mask,
                    mapper='sim',
                    input_keep_prob=self.config.input_keep_prob,
                    is_train=self.is_train)
                second_cell_fw = AttentionCell(
                    cell3_fw,
                    u,
                    mask=q_mask,
                    mapper='sim',
                    input_keep_prob=self.config.input_keep_prob,
                    is_train=self.is_train)
                second_cell_bw = AttentionCell(
                    cell3_bw,
                    u,
                    mask=q_mask,
                    mapper='sim',
                    input_keep_prob=self.config.input_keep_prob,
                    is_train=self.is_train)
            else:
                p0 = attention_layer(config,
                                     self.is_train,
                                     h,
                                     u,
                                     h_mask=self.x_mask,
                                     u_mask=self.q_mask,
                                     scope="p0",
                                     tensor_dict=self.tensor_dict)
                first_cell_fw = d_cell2_fw
                second_cell_fw = d_cell3_fw
                first_cell_bw = d_cell2_bw
                second_cell_bw = d_cell3_bw

            (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(
                first_cell_fw,
                first_cell_bw,
                p0,
                x_len,
                dtype='float',
                scope='g0')  # [N, M, JX, 2d]
            g0 = tf.concat(axis=3, values=[fw_g0, bw_g0])
            (fw_g1, bw_g1), _ = bidirectional_dynamic_rnn(
                second_cell_fw,
                second_cell_bw,
                g0,
                x_len,
                dtype='float',
                scope='g1')  # [N, M, JX, 2d]
            g1 = tf.concat(axis=3, values=[fw_g1, bw_g1])

            logits = get_logits([g1, p0],
                                d,
                                True,
                                wd=config.wd,
                                input_keep_prob=config.input_keep_prob,
                                mask=self.x_mask,
                                is_train=self.is_train,
                                func=config.answer_func,
                                scope='logits1')
            a1i = softsel(tf.reshape(g1, [N, M * JX, 2 * d]),
                          tf.reshape(logits, [N, M * JX]))
            a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1),
                          [1, M, JX, 1])

            (fw_g2, bw_g2), _ = bidirectional_dynamic_rnn(
                d_cell4_fw,
                d_cell4_bw,
                tf.concat(axis=3, values=[p0, g1, a1i, g1 * a1i]),
                x_len,
                dtype='float',
                scope='g2')  # [N, M, JX, 2d]
            g2 = tf.concat(axis=3, values=[fw_g2, bw_g2])
            logits2 = get_logits([g2, p0],
                                 d,
                                 True,
                                 wd=config.wd,
                                 input_keep_prob=config.input_keep_prob,
                                 mask=self.x_mask,
                                 is_train=self.is_train,
                                 func=config.answer_func,
                                 scope='logits2')

            flat_logits = tf.reshape(logits, [-1, M * JX])
            flat_yp = tf.nn.softmax(flat_logits)  # [-1, M*JX]
            flat_logits2 = tf.reshape(logits2, [-1, M * JX])
            flat_yp2 = tf.nn.softmax(flat_logits2)
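            # flat_yp / flat_yp2: softmax distributions over the flattened
            # M * JX context positions for the answer start and end.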

            if config.na:
                na_bias = tf.get_variable("na_bias", shape=[], dtype='float')
                na_bias_tiled = tf.tile(tf.reshape(na_bias, [1, 1]),
                                        [N, 1])  # [N, 1]
                concat_flat_logits = tf.concat(
                    axis=1, values=[na_bias_tiled, flat_logits])
                concat_flat_yp = tf.nn.softmax(concat_flat_logits)
                na_prob = tf.squeeze(tf.slice(concat_flat_yp, [0, 0], [-1, 1]),
                                     [1])
                flat_yp = tf.slice(concat_flat_yp, [0, 1], [-1, -1])

                concat_flat_logits2 = tf.concat(
                    axis=1, values=[na_bias_tiled, flat_logits2])
                concat_flat_yp2 = tf.nn.softmax(concat_flat_logits2)
                na_prob2 = tf.squeeze(
                    tf.slice(concat_flat_yp2, [0, 0], [-1, 1]), [1])  # [N]
                flat_yp2 = tf.slice(concat_flat_yp2, [0, 1], [-1, -1])

                self.concat_logits = concat_flat_logits
                self.concat_logits2 = concat_flat_logits2
                self.na_prob = na_prob * na_prob2

            yp = tf.reshape(flat_yp, [-1, M, JX])
            yp2 = tf.reshape(flat_yp2, [-1, M, JX])
            wyp = tf.nn.sigmoid(logits2)

            self.tensor_dict['g1'] = g1
            self.tensor_dict['g2'] = g2

            self.logits = flat_logits
            self.logits2 = flat_logits2
            self.yp = yp
            self.yp2 = yp2
            self.wyp = wyp

        with tf.variable_scope("q_gen"):
            # Question Generation Using (Paragraph & Predicted Ans Pos)
            NM = config.max_num_sents * config.batch_size

            # Separated encoder
            #ss = tf.reshape(xx, (-1, JX, dw+dco))

            q_worthy = tf.reduce_sum(
                tf.to_int32(self.y), axis=2
            )  # answer-token count per sentence, i.e. how question-worthy each sentence is. (N, M)
            q_worthy = tf.expand_dims(tf.to_int32(tf.argmax(q_worthy, axis=1)),
                                      axis=1)  # (N) -> (N, 1)
            q_worthy = tf.concat([
                tf.expand_dims(tf.range(0, N, dtype=tf.int32), axis=1),
                q_worthy
            ],
                                 axis=1)
            # example : [0, 9], [1, 11], [2, 8], [3, 5], [4, 0], [5, 1] ...

            ss = tf.gather_nd(xx, q_worthy)
            syp = tf.expand_dims(tf.gather_nd(yp, q_worthy), axis=-1)
            syp2 = tf.expand_dims(tf.gather_nd(yp2, q_worthy), axis=-1)
            ss_with_ans = tf.concat([ss, syp, syp2], axis=2)

            qg_dim = 600
            cell_fw, cell_bw = rnn.DropoutWrapper(rnn.GRUCell(qg_dim), input_keep_prob=config.input_keep_prob), \
                               rnn.DropoutWrapper(rnn.GRUCell(qg_dim), input_keep_prob=config.input_keep_prob)
            s_outputs, s_states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, ss_with_ans, dtype=tf.float32)
            s_outputs = tf.concat(s_outputs, axis=2)
            s_states = tf.concat(s_states, axis=1)

            start_tokens = tf.zeros([N], dtype=tf.int32)
            self.inp_q_with_GO = tf.concat(
                [tf.expand_dims(start_tokens, axis=1), self.q], axis=1)
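            # Prepends the GO token so the decoder input is the question
            # shifted right by one step.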
            # supervise if mode is train
            if config.mode == "train":
                emb_q = tf.nn.embedding_lookup(params=word_emb_mat,
                                               ids=self.inp_q_with_GO)
                #emb_q = tf.reshape(tf.tile(tf.expand_dims(emb_q, axis=1), [1, M, 1, 1]), (NM, JQ+1, dw))
                train_helper = seq2seq.TrainingHelper(emb_q, [JQ] * N)
            else:
                s_outputs = seq2seq.tile_batch(s_outputs,
                                               multiplier=beam_width)
                s_states = seq2seq.tile_batch(s_states, multiplier=beam_width)

            cell = rnn.DropoutWrapper(rnn.GRUCell(num_units=qg_dim * 2),
                                      input_keep_prob=config.input_keep_prob)
            attention_mechanism = seq2seq.BahdanauAttention(num_units=qg_dim *
                                                            2,
                                                            memory=s_outputs)
            attn_cell = seq2seq.AttentionWrapper(cell,
                                                 attention_mechanism,
                                                 attention_layer_size=qg_dim *
                                                 2,
                                                 output_attention=True,
                                                 alignment_history=False)
            total_glove_vocab_size = 78878  # previously 72686
            out_cell = rnn.OutputProjectionWrapper(attn_cell,
                                                   VW + total_glove_vocab_size)
            if config.mode == "train":
                decoder_initial_states = out_cell.zero_state(
                    batch_size=N, dtype=tf.float32).clone(cell_state=s_states)
                decoder = seq2seq.BasicDecoder(
                    cell=out_cell,
                    helper=train_helper,
                    initial_state=decoder_initial_states)
            else:
                decoder_initial_states = out_cell.zero_state(
                    batch_size=N * beam_width,
                    dtype=tf.float32).clone(cell_state=s_states)
                decoder = seq2seq.BeamSearchDecoder(
                    cell=out_cell,
                    embedding=word_emb_mat,
                    start_tokens=start_tokens,
                    end_token=EOS_TOKEN,
                    initial_state=decoder_initial_states,
                    beam_width=beam_width,
                    length_penalty_weight=0.0)
            outputs = seq2seq.dynamic_decode(decoder=decoder,
                                             maximum_iterations=JQ)
            if config.mode == "train":
                gen_q = outputs[0].sample_id
                gen_q_prob = outputs[0].rnn_output
                gen_q_states = outputs[1]
            else:
                gen_q = outputs[0].predicted_ids[:, :, 0]
                gen_q_prob = tf.nn.embedding_lookup(
                    params=word_emb_mat, ids=outputs[0].predicted_ids[:, :, 0])
                gen_q_states = outputs[1]

            self.gen_q = gen_q
            self.gen_q_prob = gen_q_prob
            self.gen_q_states = gen_q_states
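A recurring pattern in these examples: for beam search, every encoder-side tensor the decoder reads (attention memory, memory lengths, initial state) is tiled beam_width times, while start_tokens keep the original batch size. A minimal self-contained sketch of that pattern; the sizes and token ids are made-up assumptions, not values from the example above:

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

batch, beam, vocab, dim = 4, 3, 100, 32
emb = tf.get_variable("emb", [vocab, dim])
memory = tf.random_normal([batch, 10, dim])            # stand-in encoder outputs
mem_len = tf.fill([batch], 10)

# tile everything the decoder attends over
tiled_memory = seq2seq.tile_batch(memory, multiplier=beam)
tiled_len = seq2seq.tile_batch(mem_len, multiplier=beam)

attn = seq2seq.BahdanauAttention(dim, memory=tiled_memory,
                                 memory_sequence_length=tiled_len)
cell = seq2seq.AttentionWrapper(rnn.GRUCell(dim), attn,
                                attention_layer_size=dim)
init = cell.zero_state(batch * beam, tf.float32)       # tiled batch size

decoder = seq2seq.BeamSearchDecoder(
    cell, embedding=emb,
    start_tokens=tf.zeros([batch], tf.int32),          # NOT tiled
    end_token=1, initial_state=init, beam_width=beam,
    output_layer=tf.layers.Dense(vocab))
outputs, _, _ = seq2seq.dynamic_decode(decoder, maximum_iterations=20)
best_ids = outputs.predicted_ids[:, :, 0]              # top beam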
Beispiel #24
0
def rbmE_gruD(mode, features, labels, params):
    inp = features["x"]

    # `state` was undefined in the original; assume it mirrors `mode` as a
    # string (an assumption, since the source does not show where it is set)
    state = {tf.estimator.ModeKeys.TRAIN: "Training",
             tf.estimator.ModeKeys.EVAL: "Evaluating"}.get(mode, "Infering")

    if state != "Infering":
        ids = features["ids"]
        weights = features["weights"]

    batch_size = params["batch_size"]

    #Encoder
    enc_cell = rnn.NASCell(num_units=NUM_UNITS)
    enc_out, enc_state = tf.nn.dynamic_rnn(enc_cell,
                                           inp,
                                           time_major=False,
                                           dtype=tf.float32)

    #Decoder
    cell = rnn.NASCell(num_units=NUM_UNITS)

    _, embeddings = load_processed_embeddings(sess=tf.InteractiveSession())
    out_lengths = tf.constant(seq_len, shape=[batch_size])
    if state != "Infering":
        #sampling method for training
        train_helper = seq2seq.TrainingHelper(labels,
                                              out_lengths,
                                              time_major=False)
        '''
        train_helper=seq2seq.ScheduledEmbeddingTrainingHelper(inputs=labels,
                                                              sequence_length=out_lengths,
                                                              embedding=embeddings,
                                                              sampling_probability=probs)
        '''
    #sampling method for evaluation
    start_tokens = tf.zeros([batch_size], dtype=tf.int32)
    infer_helper = seq2seq.GreedyEmbeddingHelper(embedding=embeddings,
                                                 start_tokens=start_tokens,
                                                 end_token=END)
    #infer_helper = seq2seq.SampleEmbeddingHelper(embeddings,start_tokens=start_tokens,end_token=END)
    #infer_helper=seq2seq.ScheduledEmbeddingTrainingHelper(inputs=inp,sequence_length=out_lengths,embedding=embeddings,sampling_probability=1.0)
    projection_layer = layers_core.Dense(vocab_size, use_bias=False)

    def decode(helper):
        decoder = seq2seq.BasicDecoder(cell=cell,
                                       helper=helper,
                                       initial_state=enc_state,
                                       output_layer=projection_layer)
        #decoder.tracks_own_finished=True
        (dec_outputs, _,
         _) = seq2seq.dynamic_decode(decoder, maximum_iterations=seq_len)
        #(dec_outputs,_,_) = seq2seq.dynamic_decode(decoder)
        dec_ids = dec_outputs.sample_id
        logits = dec_outputs.rnn_output
        return dec_ids, logits

    #equalize logits, labels and weight lengths in case the decoder finishes early
    def norm_logits_loss(logts, ids, weights):
        current_ts = tf.to_int32(
            tf.minimum(tf.shape(ids)[1],
                       tf.shape(logts)[1]))
        logts = tf.slice(logts, begin=[0, 0, 0], size=[-1, current_ts, -1])
        ids = tf.slice(ids, begin=[0, 0], size=[-1, current_ts])
        weights = tf.slice(weights, begin=[0, 0], size=[-1, current_ts])
        return logts, ids, weights

    #training mode
    if state == "Training":
        dec_ids, logits = decode(train_helper)
        # some sample_id entries are overwritten with -1
        #dec_ids = tf.argmax(logits, axis=2)
        tf.identity(dec_ids, name="predictions")
        logits, ids, weights = norm_logits_loss(logits, ids, weights)
        loss = tf.contrib.seq2seq.sequence_loss(logits, ids, weights=weights)
        learning_rate = 0.001  #0.0001

        tf.identity(learning_rate, name="learning_rate")

    #evaluation mode
    if state == "Evaluating" or state == "Testing":
        eval_dec_ids, eval_logits = decode(infer_helper)
        #eval_dec_ids = tf.argmax(eval_logits, axis=2)
        tf.identity(eval_dec_ids, name="predictions")

        #equalize logits, labels and weight lengths in case the decoder finishes early
        eval_logits, ids, weights = norm_logits_loss(eval_logits, ids, weights)
        '''
        current_ts = tf.to_int32(tf.minimum(tf.shape(ids)[1], tf.shape(eval_logits)[1]))
        ids = tf.slice(ids, begin=[0, 0], size=[-1, current_ts])
        weights = tf.slice(weights, begin=[0, 0], size=[-1, current_ts])
        #mask_ = tf.sequence_mask(lengths=target_sequence_length, maxlen=current_ts, dtype=eval_logits.dtype)
        eval_logits = tf.slice(eval_logits, begin=[0,0,0], size=[-1, current_ts, -1])       
        '''
        eval_loss = tf.contrib.seq2seq.sequence_loss(eval_logits,
                                                     ids,
                                                     weights=weights)

    #beamSearch decoder
    init_state = tf.contrib.seq2seq.tile_batch(enc_state, multiplier=5)
    beamSearch_decoder = seq2seq.BeamSearchDecoder(
        cell,
        embeddings,
        start_tokens,
        end_token=END,
        initial_state=init_state,
        beam_width=5,
        output_layer=projection_layer)
    (infer_outputs, _, _) = seq2seq.dynamic_decode(beamSearch_decoder,
                                                   maximum_iterations=seq_len)
    infer_ids = infer_outputs.predicted_ids
    infer_probs = infer_outputs.beam_search_decoder_output.scores
    # scores are cumulative log-probabilities [batch, time, beam_width]; the
    # final score of each beam is the value at the last step (the original
    # reduce_prod across time is incorrect for log-probabilities)
    infer_probs = infer_probs[:, -1, :]
    infer_pos = tf.argmax(infer_probs, axis=1)
    infers = {"ids": infer_ids, "pos": infer_pos}

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = layers.optimize_loss(loss,
                                        tf.train.get_global_step(),
                                        optimizer='Adam',
                                        learning_rate=learning_rate,
                                        clip_gradients=5.0)

        spec = tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=dec_ids,
                                          loss=loss,
                                          train_op=train_op)
    #evaluation mode
    elif mode == tf.estimator.ModeKeys.EVAL:
        spec = tf.estimator.EstimatorSpec(mode=mode,
                                          loss=eval_loss,
                                          predictions=eval_dec_ids)
    else:
        spec = tf.estimator.EstimatorSpec(mode=mode, predictions=infers)
    return spec
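This model_fn relies on module-level names that are not shown here (NUM_UNITS, END, seq_len, vocab_size, probs, and the load_processed_embeddings helper). A hypothetical sketch of wiring it into an Estimator, assuming those names exist in scope; the paths, sizes, and input pipeline are placeholders:

import tensorflow as tf

estimator = tf.estimator.Estimator(model_fn=rbmE_gruD,
                                   model_dir="/tmp/rbmE_gruD",
                                   params={"batch_size": 32})

def input_fn():
    # toy pipeline; real features come from the upstream RBM encoder
    feats = {"x": tf.random_normal([32, seq_len, NUM_UNITS]),
             "ids": tf.zeros([32, seq_len], tf.int32),
             "weights": tf.ones([32, seq_len])}
    labels = tf.random_normal([32, seq_len, NUM_UNITS])  # TrainingHelper inputs
    return feats, labels

estimator.train(input_fn=input_fn, steps=100)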
Beispiel #25
0
    def __init__(self,
                 vocab_size,
                 hidden_size,
                 dropout,
                 num_layers,
                 max_gradient_norm,
                 batch_size,
                 learning_rate,
                 lr_decay_factor,
                 max_target_length,
                 max_source_length,
                 decoder_mode=False):
        '''
        vocab_size: number of vocab tokens
        hidden_size: dimension of hidden layers
        dropout: input keep probability for the dropout wrappers
        num_layers: number of hidden layers
        max_gradient_norm: maximum gradient magnitude
        batch_size: number of training examples fed to network at once
        learning_rate: starting learning rate of network
        lr_decay_factor: amount by which to decay learning rate
        max_target_length: maximum target sequence length
        max_source_length: maximum source sequence length
        decoder_mode: whether to build the beam-search inference graph
            instead of the training graph
        '''
        GO_ID = config.GO_ID
        EOS_ID = config.EOS_ID
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = learning_rate
        self.encoder_inputs = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='encoder_inputs')
        self.source_lengths = tf.placeholder(shape=(None, ),
                                             dtype=tf.int32,
                                             name='source_lengths')

        self.decoder_targets = tf.placeholder(shape=(None, None),
                                              dtype=tf.int32,
                                              name='decoder_targets')
        self.target_lengths = tf.placeholder(shape=(None, ),
                                             dtype=tf.int32,
                                             name="target_lengths")

        with tf.variable_scope('embeddings') as scope:
            embeddings = tf.Variable(tf.random_uniform(
                [vocab_size, hidden_size], -1.0, 1.0),
                                     dtype=tf.float32)
            encoder_inputs_embedded = tf.nn.embedding_lookup(
                embeddings, self.encoder_inputs)
            targets_embedding = tf.nn.embedding_lookup(embeddings,
                                                       self.decoder_targets)

        with tf.variable_scope('encoder') as scope:

            def make_cell():
                cell = rnn.LSTMCell(hidden_size)
                return rnn.DropoutWrapper(cell, input_keep_prob=dropout)

            # each layer (and each direction) needs its own cell object;
            # reusing one instance would share weights across layers and
            # directions
            encoder_cell_fw = tf.nn.rnn_cell.MultiRNNCell(
                [make_cell() for _ in range(num_layers)], state_is_tuple=True)
            encoder_cell_bw = tf.nn.rnn_cell.MultiRNNCell(
                [make_cell() for _ in range(num_layers)], state_is_tuple=True)

            encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=encoder_cell_fw,
                cell_bw=encoder_cell_bw,
                sequence_length=self.source_lengths,
                inputs=encoder_inputs_embedded,
                dtype=tf.float32,
                time_major=False)  #BiLSTM encoder
            encoder_output = encoder_outputs[0]
            encoder_outputs = tf.concat(encoder_outputs, 2)

        with tf.variable_scope('decoder') as scope:
            # again, build a fresh cell object per layer
            decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
                [make_cell() for _ in range(num_layers)], state_is_tuple=True)


            attn_mech = seq2seq.BahdanauAttention(
                num_units=hidden_size,  # depth of the query mechanism
                memory=encoder_output,  # encoder RNN hidden states (memory)
                memory_sequence_length=self.source_lengths,
                name='BahdanauAttention')
            attn_cell = seq2seq.AttentionWrapper(
                cell=decoder_cell,  # the multi-layer decoder cell
                attention_mechanism=attn_mech,
                attention_layer_size=hidden_size,  # depth of attention tensor
                name='attention_wrapper')  # attention layer

        if decoder_mode:
            beam_width = 1

            # encoder_state is a (forward, backward) pair; use the forward
            # state, tiled beam_width times to match the beam batch (a no-op
            # at beam_width == 1). For beam_width > 1 the attention memory
            # and source lengths would need seq2seq.tile_batch as well.
            tiled_encoder_state = seq2seq.tile_batch(encoder_state[0],
                                                     multiplier=beam_width)
            attn_zero = attn_cell.zero_state(batch_size=(batch_size *
                                                         beam_width),
                                             dtype=tf.float32)
            init_state = attn_zero.clone(cell_state=tiled_encoder_state)
            decoder = seq2seq.BeamSearchDecoder(
                cell=attn_cell,
                embedding=embeddings,
                start_tokens=tf.fill([batch_size], GO_ID),  # one per example
                end_token=EOS_ID,
                initial_state=init_state,
                beam_width=beam_width,
                output_layer=Dense(vocab_size))  #BeamSearch in Decoder
            final_outputs, final_state, final_sequence_lengths =\
                            seq2seq.dynamic_decode(decoder=decoder)
            self.logits = final_outputs.predicted_ids
        else:
            helper = seq2seq.TrainingHelper(
                inputs=targets_embedding, sequence_length=self.target_lengths)
            decoder = seq2seq.BasicDecoder(
                cell=attn_cell,
                helper=helper,
                #initial_state=attn_cell.zero_state(batch_size, tf.float32),
                initial_state=attn_cell.zero_state(
                    batch_size, tf.float32).clone(cell_state=encoder_state[0]),
                output_layer=Dense(vocab_size))
            final_outputs, final_state, final_sequence_lengths =\
                            seq2seq.dynamic_decode(decoder=decoder)

            self.logits = final_outputs.rnn_output

        if not decoder_mode:
            with tf.variable_scope("loss") as scope:
                #pad the logits: dynamic_decode can stop early, so its output
                #length may not match the targets
                pad_size = self.max_target_length - tf.reduce_max(
                    final_sequence_lengths)
                self.logits = tf.pad(self.logits,
                                     [[0, 0], [0, pad_size], [0, 0]])

                weights = tf.sequence_mask(lengths=final_sequence_lengths,
                                           maxlen=self.max_target_length,
                                           dtype=tf.float32,
                                           name='weights')

                x_entropy_loss = seq2seq.sequence_loss(
                    logits=self.logits,
                    targets=self.decoder_targets,
                    weights=weights)  #cross-entropy loss function

                self.loss = tf.reduce_mean(x_entropy_loss)

            optimizer = tf.train.AdamOptimizer()  #Adam optimization algorithm
            gradients = optimizer.compute_gradients(x_entropy_loss)
            capped_grads = [(tf.clip_by_value(grad, -max_gradient_norm,
                                              max_gradient_norm), var)
                            for grad, var in gradients]
            self.train_op = optimizer.apply_gradients(
                capped_grads, global_step=self.global_step)
            self.saver = tf.train.Saver(tf.global_variables())
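A hypothetical training-step sketch for this class, assuming the enclosing class is named Seq2SeqModel and that config.GO_ID / config.EOS_ID are defined; the hyperparameters and zero-filled batches are placeholders:

import numpy as np
import tensorflow as tf

model = Seq2SeqModel(vocab_size=8000, hidden_size=256, dropout=0.8,
                     num_layers=2, max_gradient_norm=5.0, batch_size=32,
                     learning_rate=0.001, lr_decay_factor=0.99,
                     max_target_length=30, max_source_length=30,
                     decoder_mode=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, loss = sess.run(
        [model.train_op, model.loss],
        feed_dict={model.encoder_inputs: np.zeros((32, 30), np.int32),
                   model.source_lengths: np.full((32,), 30, np.int32),
                   model.decoder_targets: np.zeros((32, 30), np.int32),
                   model.target_lengths: np.full((32,), 30, np.int32)})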
Beispiel #26
0
    def build_model(self):
        '''
        Build the seq2seq model.
        '''
        self.query_input = tf.placeholder(tf.int32, [None, None])
        self.query_length = tf.placeholder(tf.int32, [None])

        self.answer_input = tf.placeholder(tf.int32, [None, None])
        self.answer_target = tf.placeholder(tf.int32, [None, None])
        self.answer_length = tf.placeholder(tf.int32, [None])
        self.batch_size = array_ops.shape(self.query_input)[0]

        if self.mode == "train":
            self.max_decode_step = tf.reduce_max(self.answer_length)
            self.sequence_mask = tf.sequence_mask(self.answer_length,
                                                  self.max_decode_step,
                                                  dtype=tf.float32)
        elif self.mode == "decode":
            self.max_decode_step = tf.reduce_max(self.query_length) * 10

        # input and output embedding
        self.embeddings_matrix = tf.Variable(tf.random_uniform(
            [self.vocab_size, EMBEDDING_SIZE], -1.0, 1.0),
                                             dtype=tf.float32)

        self.query_embeddings = tf.nn.embedding_lookup(self.embeddings_matrix,
                                                       self.query_input)
        self.answer_embeddings = tf.nn.embedding_lookup(
            self.embeddings_matrix, self.answer_input)

        # encoder process
        self.encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
            rnn.BasicLSTMCell(ENCODER_HIDDEN_SIZE),
            self.query_embeddings,
            sequence_length=self.query_length,
            dtype=tf.float32)

        # Stage temporary copies of the encoder results for beam search; reused below
        batch_size, encoder_outputs, encoder_state, encoder_length = (
            self.batch_size, self.encoder_outputs, self.encoder_state,
            self.query_length)

        if self.mode == "decode":
            batch_size = batch_size * BEAM_WIDTH
            encoder_outputs = seq2seq.tile_batch(t=self.encoder_outputs,
                                                 multiplier=BEAM_WIDTH)
            encoder_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(t=s, multiplier=BEAM_WIDTH),
                self.encoder_state)
            encoder_length = seq2seq.tile_batch(t=self.query_length,
                                                multiplier=BEAM_WIDTH)

        # attention wrapper
        self.attention_mechanism = seq2seq.BahdanauAttention(
            num_units=ENCODER_HIDDEN_SIZE,
            memory=encoder_outputs,
            memory_sequence_length=encoder_length)
        self.decoder_cell = seq2seq.AttentionWrapper(
            rnn.BasicLSTMCell(DECODER_HIDDEN_SIZE),
            attention_mechanism=self.attention_mechanism,
            attention_layer_size=ATTENTION_SIZE)
        self.decoder_initial_state = self.decoder_cell.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=encoder_state)

        self.decoder_dense = tf.layers.Dense(
            self.vocab_size,
            dtype=tf.float32,
            use_bias=False,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                               stddev=0.1))

        # In training, use a TrainingHelper; otherwise use a greedy helper or the beam-search decoder
        if self.mode == "train":
            training_helper = seq2seq.TrainingHelper(
                inputs=self.answer_embeddings,
                sequence_length=self.answer_length)
            training_decoder = seq2seq.BasicDecoder(
                cell=self.decoder_cell,
                helper=training_helper,
                initial_state=self.decoder_initial_state,
                output_layer=self.decoder_dense)

            decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=training_decoder,
                impute_finished=True,
                maximum_iterations=self.max_decode_step)
            self.decoder_logits = tf.identity(decoder_outputs.rnn_output)

            self.loss = seq2seq.sequence_loss(
                logits=decoder_outputs.rnn_output,
                targets=self.answer_target,
                weights=self.sequence_mask)
            self.sample_ids = decoder_outputs.sample_id

            self.optimizer = tf.train.AdamOptimizer(LR_RATE)
            self.train_op = self.optimizer.minimize(self.loss)

            tf.summary.scalar('loss', self.loss)
            self.summary_op = tf.summary.merge_all()
        elif self.mode == "decode":
            start_tokens = tf.ones([self.batch_size], tf.int32) * self.go
            end_token = self.eos

            # With beam search, the values handed to the decoder here do not
            # need to be pre-tiled into BEAM_WIDTH-sized tensors.
            # Either beam-search or greedy decoding works here; the two are
            # equivalent when only one reply is returned.
            if USE_BEAMSEARCH:
                inference_decoder = seq2seq.BeamSearchDecoder(
                    cell=self.decoder_cell,
                    embedding=self.embeddings_matrix,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=self.decoder_initial_state,
                    beam_width=BEAM_WIDTH,
                    output_layer=self.decoder_dense)
                # With beam_search, the result holds predicted_ids and beam_search_decoder_output
                # predicted_ids: [batch_size, decoder_targets_length, beam_size]
                # beam_search_decoder_output: scores, predicted_ids, parent_ids
                decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder=inference_decoder,
                    maximum_iterations=self.max_decode_step)
                self.sample_ids = decoder_outputs.predicted_ids
                self.sample_ids = tf.transpose(self.sample_ids,
                                               perm=[0, 2, 1])  # transpose so each row is one sentence
            else:
                decoding_helper = seq2seq.GreedyEmbeddingHelper(
                    start_tokens=start_tokens,
                    end_token=end_token,
                    embedding=self.embeddings_matrix)
                inference_decoder = seq2seq.BasicDecoder(
                    cell=self.decoder_cell,
                    helper=decoding_helper,
                    initial_state=self.decoder_initial_state,
                    output_layer=self.decoder_dense)
                # Without beam_search, the result holds rnn_output and sample_id:
                # rnn_output: [batch_size, decoder_targets_length, vocab_size]
                # sample_id: [batch_size, decoder_targets_length], tf.int32
                self.decoder_outputs_decode, self.final_state, _ = seq2seq.dynamic_decode(
                    decoder=inference_decoder,
                    maximum_iterations=self.max_decode_step)
                self.sample_ids = self.decoder_outputs_decode.sample_id
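The nest.map_structure + tile_batch combination above is needed because LSTM states are tuples, not plain tensors. A standalone illustration with made-up sizes (tile_batch also accepts the nested structure directly, so the map_structure form is simply explicit):

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq
from tensorflow.python.util import nest

cell = rnn.BasicLSTMCell(16)
state = cell.zero_state(batch_size=4, dtype=tf.float32)  # LSTMStateTuple(c, h)

# applied leaf by leaf: every tensor goes from [4, 16] to [12, 16]
tiled = nest.map_structure(lambda s: seq2seq.tile_batch(s, multiplier=3),
                           state)
print(tiled.c.shape)  # (12, 16)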
Beispiel #27
0
    def sample(self,
               n,
               max_length=None,
               z=None,
               temperature=None,
               start_inputs=None,
               beam_width=None,
               end_token=None):
        """Overrides BaseLstmDecoder `sample` method to add optional beam search.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required if
        data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      temperature: (Optional) The softmax temperature to use when not doing beam
        search. Defaults to 1.0. Ignored when `beam_width` is provided.
      start_inputs: (Optional) Initial inputs to use for batch.
        Sized `[n, output_depth]`.
      beam_width: (Optional) Width of beam to use for beam search. Beam search
        is disabled if not provided.
      end_token: (Optional) Scalar token signaling the end of the sequence to
        use for early stopping.
    Returns:
      samples: Sampled sequences. Sized `[n, max_length, output_depth]`.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`.
    """
        if beam_width is None:
            end_fn = (None if end_token is None else
                      lambda x: tf.equal(tf.argmax(x, axis=-1), end_token))
            return super(CategoricalLstmDecoder,
                         self).sample(n, max_length, z, temperature,
                                      start_inputs, end_fn)

        # If `end_token` is not given, use an impossible value.
        end_token = self._output_depth if end_token is None else end_token
        if z is not None and z.shape[0].value != n:
            raise ValueError(
                '`z` must have a first dimension that equals `n` when given. '
                'Got: %d vs %d' % (z.shape[0].value, n))

        if temperature is not None:
            tf.logging.warning(
                '`temperature` is ignored when using beam search.')
        # Use a dummy Z in unconditional case.
        z = tf.zeros((n, 0), tf.float32) if z is None else z

        # If not given, start with dummy `-1` token and replace with zero vectors in
        # `embedding_fn`.
        start_tokens = (tf.argmax(start_inputs, axis=-1, output_type=tf.int32)
                        if start_inputs is not None else -1 *
                        tf.ones([n], dtype=tf.int32))

        initial_state = initial_cell_state_from_embedding(
            self._dec_cell, z, name='decoder/z_to_initial_state')
        beam_initial_state = seq2seq.tile_batch(initial_state,
                                                multiplier=beam_width)

        # Tile `z` across beams.
        beam_z = tf.tile(tf.expand_dims(z, 1), [1, beam_width, 1])

        def embedding_fn(tokens):
            # If tokens are the start_tokens (negative), replace with zero vectors.
            next_inputs = tf.cond(
                tf.less(tokens[0, 0], 0),
                lambda: tf.zeros([n, beam_width, self._output_depth]),
                lambda: tf.one_hot(tokens, self._output_depth))

            # Concatenate `z` to next inputs.
            next_inputs = tf.concat([next_inputs, beam_z], axis=-1)
            return next_inputs

        decoder = seq2seq.BeamSearchDecoder(self._dec_cell,
                                            embedding_fn,
                                            start_tokens,
                                            end_token,
                                            beam_initial_state,
                                            beam_width,
                                            output_layer=self._output_layer,
                                            length_penalty_weight=0.0)

        final_output, _, _ = seq2seq.dynamic_decode(
            decoder,
            maximum_iterations=max_length,
            swap_memory=True,
            scope='decoder')

        return tf.one_hot(final_output.predicted_ids[:, :, 0],
                          self._output_depth)
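BeamSearchDecoder's embedding argument accepts either an embedding matrix or a callable from token ids to inputs, which is what lets the example above concatenate z onto every step's input. A minimal sketch of the callable form; the sizes and token ids are assumptions:

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

depth, beam, n = 10, 4, 2
cell = rnn.BasicLSTMCell(16)

def embedding_fn(tokens):
    # tokens: [batch, beam_width] int32 -> inputs [batch, beam_width, depth]
    return tf.one_hot(tokens, depth)

decoder = seq2seq.BeamSearchDecoder(
    cell, embedding=embedding_fn,
    start_tokens=tf.zeros([n], tf.int32), end_token=depth - 1,
    initial_state=seq2seq.tile_batch(cell.zero_state(n, tf.float32),
                                     multiplier=beam),
    beam_width=beam,
    output_layer=tf.layers.Dense(depth))
outputs, _, _ = seq2seq.dynamic_decode(decoder, maximum_iterations=8)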
Beispiel #28
0
    def build_model(self):
        """
        build model
        :return:
        """
        print('Building model...')
        # 1. Define the model's placeholders
        self.encoder_inputs = tf.placeholder(tf.int32, [None, None],
                                             name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(
            tf.int32, [None], name='encoder_inputs_length')

        self.batch_size = tf.placeholder(tf.int32, [], name='batch_size')
        self.keep_prob_dropout = tf.placeholder(tf.float32,
                                                name='keep_prob_dropout')

        self.decoder_targets = tf.placeholder(tf.int32, [None, None],
                                              name='decoder_targets')
        self.decoder_targets_length = tf.placeholder(
            tf.int32, [None], name='decoder_targets_length')

        # Take the maximum of the target sequence lengths, then build the
        # sequence-length mask from it.
        """
        tf.sequence_mask():
            tf.sequence_mask([1, 3, 2], 5)
            [[ True False False False False]
             [ True  True  True False False]
             [ True  True False False False]]
        """
        self.max_target_sequence_length = tf.reduce_max(
            self.decoder_targets_length, name='max_target_len')
        self.mask = tf.sequence_mask(self.decoder_targets_length,
                                     self.max_target_sequence_length,
                                     dtype=tf.float32,
                                     name='masks')
        # 2. Define the encoder part of the model
        with tf.variable_scope('encoder'):
            # Create the LSTM cells: two layers plus dropout
            encoder_cell = self.create_rnn_cell()
            # Build the embedding matrix, shared by the encoder and decoder
            # embedding.shape = (vocab_size, embedding_size)
            # encoder_inputs.shape = (batch_size, encoder_inputs_length)
            # encoder_inputs_embedded.shape = (batch_size, encoder_inputs_length, embedding_size)
            embedding = tf.get_variable('embedding',
                                        [self.vocab_size, self.embedding_size])
            encoder_inputs_embedded = tf.nn.embedding_lookup(
                embedding, self.encoder_inputs)
            # Encode the inputs into hidden vectors with dynamic_rnn
            # encoder_outputs feeds attention, batch_size*encoder_inputs_length*rnn_size
            # encoder_state initializes the decoder, batch_size*rnn_size
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                encoder_cell,
                encoder_inputs_embedded,
                sequence_length=self.encoder_inputs_length,
                dtype=tf.float32)
        # 3. Define the decoder part of the model
        with tf.variable_scope('decoder'):
            encoder_inputs_length = self.encoder_inputs_length
            if self.beam_search:
                # With beam search, the encoder output must be tiled beam_size times via tile_batch
                print("use beam search decoding...")
                encoder_outputs = seq2seq.tile_batch(encoder_outputs,
                                                     multiplier=self.beam_size)
                encoder_state = nest.map_structure(
                    lambda s: seq2seq.tile_batch(s, self.beam_size),
                    encoder_state)
                encoder_inputs_length = seq2seq.tile_batch(
                    self.encoder_inputs_length, multiplier=self.beam_size)

            # Define the attention mechanism to use
            attention_mechanism = seq2seq.BahdanauAttention(
                num_units=self.rnn_size,
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)
            # attention_mechanism = seq2seq.LuongAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)
            # Define the decoder-stage LSTM cells, then wrap them with the attention wrapper
            decoder_cell = self.create_rnn_cell()
            decoder_cell = seq2seq.AttentionWrapper(
                cell=decoder_cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=self.rnn_size,
                name='Attention_Wrapper')
            # With beam search, batch_size = self.batch_size * self.beam_size
            batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size
            # Initialize the decoder state directly from the encoder's final hidden state
            decoder_initial_state = decoder_cell.zero_state(
                batch_size=batch_size,
                dtype=tf.float32).clone(cell_state=encoder_state)
            output_layer = tf.layers.Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                   stddev=0.1))

            if self.mode == 'train':
                # Decoder inputs: prepend a <go> token to the targets, drop the trailing <end>, then embed
                # decoder_inputs_embedded has shape [batch_size, decoder_targets_length, embedding_size]
                ending = tf.strided_slice(self.decoder_targets, [0, 0],
                                          [self.batch_size, -1], [1, 1])
                decoder_inputs = tf.concat([
                    tf.fill([self.batch_size, 1],
                            tf.cast(self.word2id[self.goToken],
                                    dtype=tf.int32)), ending
                ], 1)
                decoder_inputs_embedded = tf.nn.embedding_lookup(
                    embedding, decoder_inputs)
                # Training uses the standard TrainingHelper + BasicDecoder combination; a custom Helper class could be substituted
                training_helper = seq2seq.TrainingHelper(
                    inputs=decoder_inputs_embedded,
                    sequence_length=self.decoder_targets_length,
                    time_major=False,
                    name='training_helper')
                training_decoder = seq2seq.BasicDecoder(
                    cell=decoder_cell,
                    helper=training_helper,
                    initial_state=decoder_initial_state,
                    output_layer=output_layer)
                # dynamic_decode returns a named tuple holding (rnn_output, sample_id)
                # rnn_output: [batch_size, decoder_targets_length, vocab_size], per-step scores over the vocabulary, used for the loss
                # sample_id: [batch_size, decoder_targets_length], tf.int32, the decoded token ids
                decoder_outputs, _, _ = seq2seq.dynamic_decode(
                    decoder=training_decoder,
                    impute_finished=True,
                    maximum_iterations=self.max_target_sequence_length)
                # Compute loss and gradients from the outputs, and define the AdamOptimizer and train_op
                self.decoder_logits_train = tf.identity(
                    decoder_outputs.rnn_output)
                self.decoder_predict_train = tf.argmax(
                    self.decoder_logits_train,
                    axis=-1,
                    name='decoder_pred_train')
                # Compute the loss with sequence_loss, passing in the mask defined above
                self.loss = seq2seq.sequence_loss(
                    logits=self.decoder_logits_train,
                    targets=self.decoder_targets,
                    weights=self.mask)
                # Training summary for the current batch_loss
                tf.summary.scalar('loss', self.loss)
                self.summary_op = tf.summary.merge_all()

                optimizer = tf.train.AdamOptimizer(self.learning_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(
                    gradients, self.max_gradient_norm)
                self.train_op = optimizer.apply_gradients(
                    zip(clip_gradients, trainable_params))
            elif self.mode == 'predict':
                start_tokens = tf.ones([
                    self.batch_size,
                ], tf.int32) * self.word2id[self.goToken]
                end_token = self.word2id[self.endToken]
                # The decode stage differs depending on beam search:
                # with beam search, call BeamSearchDecoder directly (it embeds its own helper)
                # otherwise, combine GreedyEmbeddingHelper + BasicDecoder for greedy decoding
                if self.beam_search:
                    inference_decoder = seq2seq.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=self.beam_size,
                        output_layer=output_layer)
                else:
                    decoder_helper = seq2seq.GreedyEmbeddingHelper(
                        embedding=embedding,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    inference_decoder = seq2seq.BasicDecoder(
                        cell=decoder_cell,
                        helper=decoder_helper,
                        initial_state=decoder_initial_state,
                        output_layer=output_layer)
                decoder_outputs, _, _ = seq2seq.dynamic_decode(
                    decoder=inference_decoder, maximum_iterations=10)
                # dynamic_decode returns a named tuple
                # without beam search it holds (rnn_output, sample_id)
                # rnn_output: [batch_size, decoder_targets_length, vocab_size]
                # sample_id: [batch_size, decoder_targets_length], tf.int32

                # with beam search it holds (predicted_ids, beam_search_decoder_output)
                # predicted_ids: [batch_size, decoder_targets_length, beam_size], the decoded results
                # beam_search_decoder_output: BeamSearchDecoderOutput named tuple (scores, predicted_ids, parent_ids)
                # so returning predicted_ids or sample_id is enough to recover the final answer
                if self.beam_search:
                    self.decoder_predict_decoder = decoder_outputs.predicted_ids
                else:
                    self.decoder_predict_decoder = tf.expand_dims(
                        decoder_outputs.sample_id, -1)
        # 4. Save the model
        self.saver = tf.train.Saver(tf.global_variables())
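The <go>-prepend / <end>-drop preparation of decoder_inputs above, shown standalone with a toy batch (the token ids are assumptions):

import tensorflow as tf

GO = 1
targets = tf.constant([[5, 6, 7, 2],     # 2 = <end>
                       [8, 9, 2, 0]], tf.int32)
batch_size = 2

# drop the last column, then prepend <go>
ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
decoder_inputs = tf.concat([tf.fill([batch_size, 1], GO), ending], 1)

with tf.Session() as sess:
    print(sess.run(decoder_inputs))
    # [[1 5 6 7]
    #  [1 8 9 2]]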
Beispiel #29
0
  def call(self, inputs, training=None, mask=None):
    dec_emb_fn = lambda ids: self.embed(ids)
    if self.is_infer:
      enc_outputs, enc_state, enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      helper = seq2seq.GreedyEmbeddingHelper(embedding=dec_emb_fn,
                                             start_tokens=tf.fill([batch_size],
                                                                  self.dec_start_id),
                                             end_token=self.dec_end_id)
    else:
      dec_inputs, dec_seq_len, enc_outputs, enc_state, \
      enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      dec_inputs = self.embed(dec_inputs)
      helper = seq2seq.TrainingHelper(inputs=dec_inputs,
                                      sequence_length=dec_seq_len)

    if self.is_infer and self.beam_size > 1:
      tiled_enc_outputs = seq2seq.tile_batch(enc_outputs,
                                             multiplier=self.beam_size)
      tiled_seq_len = seq2seq.tile_batch(enc_seq_len,
                                         multiplier=self.beam_size)
      attn_mech = self._build_attention(enc_outputs=tiled_enc_outputs,
                                        enc_seq_len=tiled_seq_len)
      dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
      tiled_enc_last_state = seq2seq.tile_batch(enc_state,
                                                multiplier=self.beam_size)
      tiled_dec_init_state = dec_cell.zero_state(batch_size=batch_size * self.beam_size,
                                                 dtype=tf.float32)
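      # zero_state must be built at the tiled batch size; the (tiled) encoder
      # state is cloned in only when initial_decode_state is set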
      if self.initial_decode_state:
        tiled_dec_init_state = tiled_dec_init_state.clone(cell_state=tiled_enc_last_state)

      dec = seq2seq.BeamSearchDecoder(cell=dec_cell,
                                      embedding=dec_emb_fn,
                                      start_tokens=tf.tile([self.dec_start_id],
                                                           [batch_size]),
                                      end_token=self.dec_end_id,
                                      initial_state=tiled_dec_init_state,
                                      beam_width=self.beam_size,
                                      output_layer=tf.layers.Dense(self.vocab_size),
                                      length_penalty_weight=self.length_penalty)
    else:
      attn_mech = self._build_attention(enc_outputs=enc_outputs,
                                        enc_seq_len=enc_seq_len)
      dec_cell = seq2seq.AttentionWrapper(cell=self.cell,
                                          attention_mechanism=attn_mech)
      dec_init_state = dec_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
      if self.initial_decode_state:
        dec_init_state = dec_init_state.clone(cell_state=enc_state)
      dec = seq2seq.BasicDecoder(cell=dec_cell,
                                 helper=helper,
                                 initial_state=dec_init_state,
                                 output_layer=tf.layers.Dense(self.vocab_size))
    if self.is_infer:
      dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=self.max_dec_len,
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
      return dec_outputs.predicted_ids[:, :, 0]
    else:
      dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=tf.reduce_max(dec_seq_len),
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
    return dec_outputs.rnn_output
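GreedyEmbeddingHelper, like BeamSearchDecoder, accepts a callable in place of an embedding matrix, exactly as dec_emb_fn is used above. A minimal greedy-decoding sketch with made-up sizes and token ids:

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

vocab, dim, batch = 50, 16, 3
emb = tf.get_variable("emb", [vocab, dim])
embed_fn = lambda ids: tf.nn.embedding_lookup(emb, ids)  # callable embedding

cell = rnn.GRUCell(dim)
helper = seq2seq.GreedyEmbeddingHelper(embedding=embed_fn,
                                       start_tokens=tf.fill([batch], 1),
                                       end_token=2)
decoder = seq2seq.BasicDecoder(cell=cell,
                               helper=helper,
                               initial_state=cell.zero_state(batch, tf.float32),
                               output_layer=tf.layers.Dense(vocab))
outputs, _, lengths = seq2seq.dynamic_decode(decoder, maximum_iterations=12)
greedy_ids = outputs.sample_id  # [batch, <=12]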
Beispiel #30
0
 def __call__(self, 
         top_k_attributes,
         mean_image_features=None, 
         mean_object_features=None, 
         spatial_image_features=None, 
         spatial_object_features=None, 
         seq_inputs=None, lengths=None ):
     assert(mean_image_features is not None or mean_object_features is not None or
         spatial_image_features is not None or spatial_object_features is not None)
     attribute_features = tf.nn.embedding_lookup(self.attribute_embeddings_map, top_k_attributes)
     mean_attribute_features = tf.reduce_mean(attribute_features, [1])
     use_beam_search = (seq_inputs is None or lengths is None)
     if mean_image_features is not None:
         batch_size = tf.shape(mean_image_features)[0]
         mean_image_features = tf.concat([mean_image_features, mean_attribute_features], 1)
     elif mean_object_features is not None:
         batch_size = tf.shape(mean_object_features)[0]
         mean_object_features = tf.concat([mean_object_features, attribute_features], 1)
     elif spatial_image_features is not None:
         batch_size = tf.shape(spatial_image_features)[0]
         spatial_image_features = collapse_dims(spatial_image_features, [1, 2])
         mean_image_features = tf.concat([tf.reduce_mean(spatial_image_features, [1]), 
             mean_attribute_features], 1)
         spatial_image_features = tf.concat([spatial_image_features, attribute_features], 1)
     elif spatial_object_features is not None:
         batch_size = tf.shape(spatial_object_features)[0] 
         spatial_object_features = collapse_dims(spatial_object_features, [2, 3])
         mean_object_features = tf.concat([tf.reduce_mean(spatial_object_features, [2]), 
             attribute_features], 1)
         spatial_object_features = tf.concat([spatial_object_features, 
             tf.expand_dims(attribute_features, 2)], 2)
     initial_state = self.image_caption_cell.zero_state(batch_size, tf.float32)
     if use_beam_search:
         if mean_image_features is not None:
             mean_image_features = seq2seq.tile_batch(mean_image_features, 
                 multiplier=self.beam_size)
             self.image_caption_cell.mean_image_features = mean_image_features
         if mean_object_features is not None:
             mean_object_features = seq2seq.tile_batch(mean_object_features, 
                 multiplier=self.beam_size)
             self.image_caption_cell.mean_object_features = mean_object_features
         if spatial_image_features is not None:
             spatial_image_features = seq2seq.tile_batch(spatial_image_features, 
                 multiplier=self.beam_size)
             self.image_caption_cell.spatial_image_features = spatial_image_features
         if spatial_object_features is not None:
             spatial_object_features = seq2seq.tile_batch(spatial_object_features, 
                 multiplier=self.beam_size)
             self.image_caption_cell.spatial_object_features = spatial_object_features
         initial_state = seq2seq.tile_batch(initial_state, multiplier=self.beam_size)
         decoder = seq2seq.BeamSearchDecoder(self.image_caption_cell, self.word_embeddings_map, 
             tf.fill([batch_size], self.word_vocabulary.start_id), self.word_vocabulary.end_id, 
             initial_state, self.beam_size, output_layer=self.word_logits_layer)
         outputs, state, lengths = seq2seq.dynamic_decode(decoder, 
             maximum_iterations=self.maximum_iterations)
         ids = tf.transpose(outputs.predicted_ids, [0, 2, 1])
         sequence_length = tf.shape(ids)[2]
         flat_ids = tf.reshape(ids, [batch_size * self.beam_size, sequence_length])
         seq_inputs = tf.concat([
             tf.fill([batch_size * self.beam_size, 1], self.word_vocabulary.start_id), flat_ids], 1)
     if mean_image_features is not None:
         self.image_caption_cell.mean_image_features = mean_image_features
     if mean_object_features is not None:
         self.image_caption_cell.mean_object_features = mean_object_features
     if spatial_image_features is not None:
         self.image_caption_cell.spatial_image_features = spatial_image_features
     if spatial_object_features is not None:
         self.image_caption_cell.spatial_object_features = spatial_object_features   
     activations, _state = tf.nn.dynamic_rnn(self.image_caption_cell, 
         tf.nn.embedding_lookup(self.word_embeddings_map, seq_inputs),
         sequence_length=tf.reshape(lengths, [-1]), initial_state=initial_state)
     logits = self.word_logits_layer(activations)
     if use_beam_search:
         length = tf.shape(logits)[1]
         logits = tf.reshape(logits, [batch_size, self.beam_size, length, self.vocab_size])
     return logits, tf.argmax(logits, axis=-1, output_type=tf.int32)
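collapse_dims is an external helper not shown in this example; judging from its call sites it merges the listed contiguous axes into one. A hypothetical sketch of such a helper (an assumption, not the original implementation):

import tensorflow as tf

def collapse_dims(tensor, dims):
    # merge the contiguous axes listed in `dims` into a single axis,
    # e.g. [B, H, W, C] with dims=[1, 2] becomes [B, H*W, C]
    shape = tf.shape(tensor)
    new_shape = tf.concat([shape[:dims[0]], [-1], shape[dims[-1] + 1:]],
                          axis=0)
    return tf.reshape(tensor, new_shape)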