Example 1
    def _add_ml_loss(self, is_training):
        logits = self._add_decoder(is_training=is_training,
                                   decoder_inputs=self.decoder_inputs)

        with self.graph.as_default():
            self.preds = tf.to_int32(tf.argmax(
                logits, axis=-1))  # shape: (batch_size, max_timestep)
            self.istarget = tf.to_float(tf.not_equal(
                self.xy, 0))  # shape: (batch_size, max_timestep)
            self.acc_full = tf.reduce_sum(
                tf.to_float(tf.equal(self.preds, self.xy)) *
                self.istarget) / (tf.reduce_sum(self.istarget))
            # accuracy on the summary tokens only (positions after the article)
            self.acc_sum = tf.reduce_sum(
                tf.to_float(
                    tf.equal(self.preds[:, hp.article_maxlen + 1:],
                             self.xy[:, hp.article_maxlen + 1:])) *
                self.istarget[:, hp.article_maxlen + 1:]) / tf.reduce_sum(
                    self.istarget[:, hp.article_maxlen + 1:])

            # ROUGE-L on the summary tokens only (positions after the article)
            self.rouge = tf.reduce_sum(
                rouge_l_fscore(self.preds[:, hp.article_maxlen + 1:],
                               self.xy[:, hp.article_maxlen + 1:])) / float(
                                   hp.batch_size)

            ml_loss = -100  # sentinel value returned when not training
            if is_training:
                # Loss
                self.y_smoothed = label_smoothing(
                    tf.one_hot(self.xy, depth=self.vocab_size))
                loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                    logits=logits, labels=self.y_smoothed)  # use the local logits computed above
                ml_loss = tf.reduce_sum(loss * self.istarget,
                                        name='fake_ml_loss') / (tf.reduce_sum(
                                            self.istarget))

        return ml_loss
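Note that `label_smoothing` is called above but not defined in this example. A minimal sketch of the uniform smoothing it usually implements in Transformer-style code (the `epsilon` parameter and exact form here are assumptions, not taken from this codebase):

import tensorflow as tf

def label_smoothing(inputs, epsilon=0.1):
    # Spread epsilon of the probability mass uniformly over the vocabulary,
    # keeping 1 - epsilon on the gold label.
    # inputs: one-hot targets of shape (batch_size, max_timestep, vocab_size).
    vocab_size = inputs.get_shape().as_list()[-1]
    return (1.0 - epsilon) * inputs + epsilon / vocab_size

Smoothing keeps the cross-entropy targets away from exact 0/1, which regularizes the decoder's output distribution.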
Example 2
    def _add_rl_loss(self):
        sample_logits, sample_preds = self._rl_autoinfer(
            greedy=tf.constant(False, dtype=tf.bool), name='sample_loop')
        greedy_logits, greedy_preds = self._rl_autoinfer(
            greedy=tf.constant(True, dtype=tf.bool), name='greedy_loop')
        self.sl = sample_logits

        sample_logits = tf.Print(input_=sample_logits,
                                 data=[sample_logits],
                                 message='sample_logits: ')
        greedy_logits = tf.Print(input_=greedy_logits,
                                 data=[greedy_logits],
                                 message='greedy_logits: ')

        self.reward_diff = tf.zeros(shape=())
        for sent_i, ref in enumerate(tf.unstack(self.y)):
            real_y = ref[:tf.reduce_sum(tf.to_int32(tf.not_equal(
                ref, 0)))]  # strip the trailing <PAD> tokens from this reference
            self.reward_diff += rouge_l_fscore(
                [greedy_preds[sent_i]], [real_y]) - rouge_l_fscore(
                    [sample_preds[sent_i]], [real_y])

        self.reward_diff = tf.Print(input_=self.reward_diff,
                                    data=[self.reward_diff],
                                    message='reward_diff: ')

        self.clipped_reward_diff = tf.clip_by_value(
            self.reward_diff,
            clip_value_min=-hp.max_reward_diff,
            clip_value_max=hp.max_reward_diff)

        # rl_loss = tf.reduce_sum(self.reward_diff * sample_logits) / (hp.batch_size * hp.summary_maxlen)
        rl_loss = tf.reduce_sum(self.clipped_reward_diff * sample_logits) / (
            hp.batch_size * hp.summary_maxlen)
        rl_loss = tf.Print(input_=rl_loss, data=[rl_loss], message='rl_loss: ')
        return rl_loss  # masked
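For intuition, `_add_rl_loss` implements a self-critical policy gradient: the greedy decode serves as a baseline, and sampled sequences are reinforced in proportion to how much they beat it under ROUGE-L. A minimal standalone sketch of the same objective (the function and argument names below are illustrative, not from this codebase):

import tensorflow as tf

def self_critical_loss(sample_log_probs, sample_reward, greedy_reward,
                       max_reward_diff):
    # reward_diff > 0 when the greedy baseline beats the sample, so
    # minimizing reward_diff * log p(sample) pushes probability mass
    # toward samples that outscore the baseline.
    reward_diff = tf.clip_by_value(greedy_reward - sample_reward,
                                   -max_reward_diff, max_reward_diff)
    return tf.reduce_mean(reward_diff * sample_log_probs)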
Example 3
def attention_decoder(_hps, 
  v_size, 
  _max_art_oovs, 
  _enc_batch_extend_vocab, 
  emb_dec_inputs,
  target_batch,
  _dec_in_state, 
  _enc_states, 
  enc_padding_mask, 
  dec_padding_mask, 
  cell, 
  embedding, 
  sampling_probability,
  alpha,
  unk_id,
  initial_state_attention=False,
  pointer_gen=True, 
  use_coverage=False, 
  prev_coverage=None, 
  prev_decoder_outputs=[],
  prev_encoder_es=[]):
  """
  Args:
    _hps: hyperparameters of the model.
    v_size: vocab size.
    _max_art_oovs: number of OOV tokens in the current batch.
    _enc_batch_extend_vocab: encoder extended vocab batch.
    emb_dec_inputs: A list of 2D Tensors [batch_size x emb_dim].
    target_batch: The indices of the target words. shape (max_dec_steps, batch_size)
    _dec_in_state: 2D Tensor [batch_size x cell.state_size].
    _enc_states: 3D Tensor [batch_size x max_enc_steps x attn_size].
    enc_padding_mask: 2D Tensor [batch_size x max_enc_steps] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1).
    dec_padding_mask: 2D Tensor [batch_size x max_dec_steps] containing 1s and 0s; indicates which of the decoder locations are padding (0) or a real token (1).
    cell: rnn_cell.RNNCell defining the cell function and size.
    embedding: embedding matrix [vocab_size, emb_dim].
    sampling_probability: sampling probability for scheduled sampling.
    alpha: soft-argmax argument.
    initial_state_attention:
      Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use _dec_in_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step).
    pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step.
    use_coverage: boolean. If True, use coverage mechanism.
    prev_coverage:
      If not None, a tensor with shape (batch_size, max_enc_steps). The previous step's coverage vector. This is only not None in decode mode when using coverage.
    prev_decoder_outputs: if not empty, a tensor of shape (len(prev_decoder_outputs), batch_size, hidden_dim). The previous decoder outputs, used for calculating intra-decoder attention during decode mode.
    prev_encoder_es: if not empty, a tensor of shape (len(prev_encoder_es), batch_size, hidden_dim). The previous attention scores, used for calculating temporal attention during decode mode.
  Returns:
    outputs: A list of the same length as emb_dec_inputs of 2D Tensors of
      shape [batch_size x cell.output_size]. The output vectors.
    state: The final state of the decoder. A tensor of shape [batch_size x cell.state_size].
    attn_dists: A list containing tensors of shape (batch_size, max_enc_steps).
      The attention distributions for each decoder step.
    p_gens: A list of the same length as emb_dec_inputs, containing tensors of shape [batch_size, 1]. The values of p_gen for each decoder step. Empty list if pointer_gen=False.
    coverage: Coverage vector on the last step computed. None if use_coverage=False.
    vocab_scores: vocab distribution.
    final_dists: final output distribution.
    samples: contains sampled tokens.
    greedy_search_samples: contains greedy tokens.
    temporal_e: contains temporal attention.
  """
  with variable_scope.variable_scope("attention_decoder") as scope:
    batch_size = _enc_states.get_shape()[0] # if this line fails, it's because the batch size isn't defined
    attn_size = _enc_states.get_shape()[2] # if this line fails, it's because the attention length isn't defined
    emb_size = emb_dec_inputs[0].get_shape()[1] # if this line fails, it's because the embedding isn't defined
    decoder_attn_size = _dec_in_state.c.get_shape()[1]
    tf.logging.info("batch_size %i, attn_size: %i, emb_size: %i", batch_size, attn_size, emb_size)
    # Reshape _enc_states (need to insert a dim)
    _enc_states = tf.expand_dims(_enc_states, axis=2) # now is shape (batch_size, max_enc_steps, 1, attn_size)

    # To calculate attention, we calculate
    #   v^T tanh(W_h h_i + W_s s_t + b_attn)
    # where h_i is an encoder state, and s_t a decoder state.
    # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t).
    # We set it to be equal to the size of the encoder states.
    attention_vec_size = attn_size

    # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features
    if _hps.matrix_attention:
      w_attn = variable_scope.get_variable("w_attn", [attention_vec_size, attention_vec_size])
      if _hps.intradecoder:
        w_dec_attn = variable_scope.get_variable("w_dec_attn", [decoder_attn_size, decoder_attn_size])
    else:
      W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size])
      v = variable_scope.get_variable("v", [attention_vec_size])
      encoder_features = nn_ops.conv2d(_enc_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,max_enc_steps,1,attention_vec_size)
    if _hps.intradecoder:
      W_h_d = variable_scope.get_variable("W_h_d", [1, 1, decoder_attn_size, decoder_attn_size])
      v_d = variable_scope.get_variable("v_d", [decoder_attn_size])

    # Get the weight vectors v and w_c (w_c is for coverage)
    if use_coverage:
      with variable_scope.variable_scope("coverage"):
        w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size])

    if prev_coverage is not None: # for beam search mode with coverage
      # reshape from (batch_size, max_enc_steps) to (batch_size, max_enc_steps, 1, 1)
      prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3)

    def attention(decoder_state, temporal_e, coverage=None):
      """Calculate the context vector and attention distribution from the decoder state.

      Args:
        decoder_state: state of the decoder
        temporal_e: store previous attentions for temporal attention mechanism
        coverage: Optional. Previous timestep's coverage vector, shape (batch_size, max_enc_steps, 1, 1).

      Returns:
        context_vector: weighted sum of _enc_states
        attn_dist: attention distribution
        coverage: new coverage vector. shape (batch_size, max_enc_steps, 1, 1)
        masked_e: store the attention score for temporal attention mechanism.
      """
      with variable_scope.variable_scope("Attention"):
        # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
        decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size)
        decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size)

        # We can't have coverage with matrix attention
        if not _hps.matrix_attention and use_coverage and coverage is not None: # non-first step of coverage
          # Multiply coverage vector by w_c to get coverage_features.
          coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, max_enc_steps, 1, attention_vec_size)
          # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
          e_not_masked = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3])  # shape (batch_size,max_enc_steps)
          masked_e = nn_ops.softmax(e_not_masked) * enc_padding_mask # (batch_size, max_enc_steps)
          masked_sums = tf.reduce_sum(masked_e, axis=1) # shape (batch_size)
          masked_e = masked_e / tf.reshape(masked_sums, [-1, 1])
          # Equation 3 in https://arxiv.org/pdf/1705.04304.pdf (temporal attention)
          if _hps.use_temporal_attention:
            try:
              len_temporal_e = temporal_e.get_shape()[0]
            except:
              len_temporal_e = 0
            if len_temporal_e==0:
              attn_dist = masked_e
            else:
              masked_sums = tf.reduce_sum(temporal_e,axis=0)+1e-10 # if it's zero due to masking we set it to a small value
              attn_dist = masked_e / masked_sums # (batch_size, max_enc_steps)
          else:
            attn_dist = masked_e
          masked_attn_sums = tf.reduce_sum(attn_dist, axis=1)
          attn_dist = attn_dist / tf.reshape(masked_attn_sums, [-1, 1]) # re-normalize
          # Update coverage vector
          coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1])
        else:
          if _hps.matrix_attention:
            # Calculate h_d * W_attn * h_i, equation 2 in https://arxiv.org/pdf/1705.04304.pdf
            _dec_attn = tf.unstack(tf.matmul(tf.squeeze(decoder_features,axis=[1,2]),w_attn),axis=0) # batch_size * (attention_vec_size)
            _enc_states_lst = tf.unstack(tf.squeeze(_enc_states,axis=2),axis=0) # batch_size * (max_enc_steps, attention_vec_size)

            e_not_masked = tf.squeeze(tf.stack([tf.matmul(tf.reshape(_dec,[1,-1]), tf.transpose(_enc)) for _dec, _enc in zip(_dec_attn,_enc_states_lst)]),axis=1) # (batch_size, max_enc_steps)
            masked_e = tf.exp(e_not_masked * enc_padding_mask) # (batch_size, max_enc_steps)
          else:
            # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
            e_not_masked = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e, (batch_size, max_enc_steps)
            masked_e = nn_ops.softmax(e_not_masked) * enc_padding_mask # (batch_size, max_enc_steps)
            masked_sums = tf.reduce_sum(masked_e, axis=1) # shape (batch_size)
            masked_e = masked_e / tf.reshape(masked_sums, [-1, 1])
          if _hps.use_temporal_attention:
            try:
              len_temporal_e = temporal_e.get_shape()[0]
            except:
              len_temporal_e = 0
            if len_temporal_e==0:
              attn_dist = masked_e
            else:
              masked_sums = tf.reduce_sum(temporal_e,axis=0)+1e-10 # if it's zero due to masking we set it to a small value
              attn_dist = masked_e / masked_sums # (batch_size, max_enc_steps)
          else:
            attn_dist = masked_e
          # Calculate attention distribution
          masked_attn_sums = tf.reduce_sum(attn_dist, axis=1)
          attn_dist = attn_dist / tf.reshape(masked_attn_sums, [-1, 1]) # re-normalize

          if use_coverage: # first step of training
            coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage

        # Calculate the context vector from attn_dist and _enc_states
        context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * _enc_states, [1, 2]) # shape (batch_size, attn_size).
        context_vector = array_ops.reshape(context_vector, [-1, attn_size])

      return context_vector, attn_dist, coverage, masked_e

    def intra_decoder_attention(decoder_state, outputs):
      """Calculate the context vector and attention distribution from the decoder state.

      Args:
        decoder_state: state of the decoder
        outputs: list of decoder states for implementing intra-decoder mechanism, len(decoder_states) * (batch_size, hidden_dim)
      Returns:
        context_decoder_vector: weighted sum of _dec_states
        decoder_attn_dist: intra-decoder attention distribution
      """
      attention_dec_vec_size = attn_dec_size = decoder_state.c.get_shape()[1] # hidden_dim
      try:
        len_dec_states = outputs.get_shape()[0]
      except:
        len_dec_states = 0
      _decoder_states = tf.expand_dims(tf.reshape(outputs,[batch_size,-1,attn_dec_size]), axis=2) # now is shape (batch_size,len(decoder_states), 1, attn_size)
      _prev_decoder_features = nn_ops.conv2d(_decoder_states, W_h_d, [1, 1, 1, 1], "SAME") # shape (batch_size,len(decoder_states),1,attention_vec_size)
      with variable_scope.variable_scope("DecoderAttention"):
        # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
        try:
          decoder_features = linear(decoder_state, attention_dec_vec_size, True) # shape (batch_size, attention_vec_size)
          decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_dec_vec_size)
          # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
          if _hps.matrix_attention:
            # Calculate h_d * W_attn * h_d, equation 6 in https://arxiv.org/pdf/1705.04304.pdf
            _dec_attn = tf.matmul(tf.squeeze(decoder_features),w_dec_attn) # (batch_size, decoder_attn_size)
            _dec_states_lst = tf.unstack(tf.reshape(_prev_decoder_features,[batch_size,-1,decoder_attn_size])) # batch_size * (len(decoder_states), decoder_attn_size)
            e_not_masked = tf.reshape(tf.stack([tf.matmul(_dec_attn, tf.transpose(k)) for k in _dec_states_lst]),[batch_size,-1]) # (batch_size, len(decoder_states))
            masked_e = tf.exp(e_not_masked * dec_padding_mask[:,:len_dec_states]) # (batch_size, len(decoder_states))
          else:
            # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
            e_not_masked = math_ops.reduce_sum(v_d * math_ops.tanh(_prev_decoder_features + decoder_features), [2, 3]) # calculate e, (batch_size,len(decoder_states))
            masked_e = nn_ops.softmax(e_not_masked) * dec_padding_mask[:,:len_dec_states] # (batch_size,len(decoder_states))
          if len_dec_states <= 1:
            masked_e = array_ops.ones([batch_size,1]) # first step is filled with equal values
          masked_sums = tf.reshape(tf.reduce_sum(masked_e,axis=1),[-1,1]) # (batch_size, 1)
          decoder_attn_dist = masked_e / masked_sums # (batch_size,len(decoder_states))
          context_decoder_vector = math_ops.reduce_sum(array_ops.reshape(decoder_attn_dist, [batch_size, -1, 1, 1]) * _decoder_states, [1, 2]) # (batch_size, attn_size)
          context_decoder_vector = array_ops.reshape(context_decoder_vector, [-1, attn_dec_size]) # (batch_size, attn_size)
        except:
          return array_ops.zeros([batch_size, decoder_attn_size]), array_ops.zeros([batch_size, 0])
      return context_decoder_vector, decoder_attn_dist

    outputs = []
    temporal_e = []
    attn_dists = []
    vocab_scores = []
    vocab_dists = []
    final_dists = []
    p_gens = []
    samples = [] # this holds the words chosen by sampling based on the final distribution for each decoding step, list of max_dec_steps of (batch_size, 1)
    greedy_search_samples = [] # this holds the words chosen by greedy search (taking the max) on the final distribution for each decoding step, list of max_dec_steps of (batch_size, 1)
    sampling_rewards = [] # list of size max_dec_steps (batch_size, k)
    greedy_rewards = [] # list of size max_dec_steps (batch_size, k)
    state = _dec_in_state
    coverage = prev_coverage # initialize coverage to None or whatever was passed in
    context_vector = array_ops.zeros([batch_size, attn_size])
    context_decoder_vector = array_ops.zeros([batch_size, decoder_attn_size])
    context_vector.set_shape([None, attn_size])  # Ensure the second shape of attention vectors is set.
    if initial_state_attention: # true in decode mode
      # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input
      context_vector, _, coverage, _ = attention(_dec_in_state, tf.stack(prev_encoder_es,axis=0), coverage) # in decode mode, this is what updates the coverage vector
      if _hps.intradecoder:
        context_decoder_vector, _ = intra_decoder_attention(_dec_in_state, tf.stack(prev_decoder_outputs,axis=0))
    for i, inp in enumerate(emb_dec_inputs):
      tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(emb_dec_inputs))
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()

      if _hps.mode in ['train','eval'] and _hps.scheduled_sampling and i > 0: # start scheduled sampling after we received the first decoder's output
        # modify the input to next decoder using scheduled sampling
        if FLAGS.scheduled_sampling_final_dist:
          inp = scheduled_sampling(_hps, sampling_probability, final_dist, embedding, inp, alpha)
        else:
          inp = scheduled_sampling_vocab_dist(_hps, sampling_probability, vocab_dist, embedding, inp, alpha)

      # Merge input and previous attentions into one vector x of the same size as inp
      emb_dim = inp.get_shape().with_rank(2)[1]
      if emb_dim is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)

      x = linear([inp] + [context_vector], emb_dim, True)
      # Run the decoder RNN cell. cell_output = decoder state
      cell_output, state = cell(x, state)

      # Run the attention mechanism.
      if i == 0 and initial_state_attention:  # always true in decode mode
        with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) call
          context_vector, attn_dist, _, masked_e = attention(state, tf.stack(prev_encoder_es,axis=0), coverage) # don't allow coverage to update
          if _hps.intradecoder:
            context_decoder_vector, _ = intra_decoder_attention(state, tf.stack(prev_decoder_outputs,axis=0))
      else:
        context_vector, attn_dist, coverage, masked_e = attention(state, tf.stack(temporal_e,axis=0), coverage)
        if _hps.intradecoder:
          context_decoder_vector, _ = intra_decoder_attention(state, tf.stack(outputs,axis=0))
      attn_dists.append(attn_dist)
      temporal_e.append(masked_e)

      with variable_scope.variable_scope("combined_context"):
        if _hps.intradecoder:
          context_vector = linear([context_vector] + [context_decoder_vector], attn_size, False)
      # Calculate p_gen
      if pointer_gen:
        with tf.variable_scope('calculate_pgen'):
          p_gen = linear([context_vector, state.c, state.h, x], 1, True) # Tensor shape (batch_size, 1)
          p_gen = tf.sigmoid(p_gen)
          p_gens.append(p_gen)

      # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer
      # This is V[s_t, h*_t] + b in the paper
      with variable_scope.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + [context_vector], cell.output_size, True)
      outputs.append(output)

      # Add the output projection to obtain the vocabulary distribution
      with tf.variable_scope('output_projection'):
        if i > 0:
          tf.get_variable_scope().reuse_variables()
        trunc_norm_init = tf.truncated_normal_initializer(stddev=_hps.trunc_norm_init_std)
        w_out = tf.get_variable('w', [_hps.dec_hidden_dim, v_size], dtype=tf.float32, initializer=trunc_norm_init)
        #w_t_out = tf.transpose(w)
        v_out = tf.get_variable('v', [v_size], dtype=tf.float32, initializer=trunc_norm_init)
        if FLAGS.share_decoder_weights: # Eq. 13 in https://arxiv.org/pdf/1705.04304.pdf
          w_out = tf.transpose(
            math_ops.tanh(linear([embedding] + [tf.transpose(w_out)], _hps.dec_hidden_dim, bias=False)))
        score = tf.nn.xw_plus_b(output, w_out, v_out)
        if _hps.scheduled_sampling and not _hps.greedy_scheduled_sampling:
          # Gumbel reparametrization trick: https://arxiv.org/abs/1704.06970
          U = tf.random_uniform(score.get_shape(),10e-12,(1-10e-12)) # add a small number to avoid log(0)
          G = -tf.log(-tf.log(U))
          score = score + G
        vocab_scores.append(score) # apply the linear layer
        vocab_dist = tf.nn.softmax(score)
        vocab_dists.append(vocab_dist) # The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.

      # For pointer-generator model, calc final distribution from copy distribution and vocabulary distribution
      if _hps.pointer_gen:
        final_dist = _calc_final_dist(_hps, v_size, _max_art_oovs, _enc_batch_extend_vocab, p_gen, vocab_dist,
                                      attn_dist)
      else: # final distribution is just vocabulary distribution
        final_dist = vocab_dist
      final_dists.append(final_dist)

      # get the sampled token and greedy token
      # this will take the final_dist and sample from it for a total count of k (k samples)
      one_hot_k_samples = tf.distributions.Multinomial(total_count=1., probs=final_dist).sample(
        _hps.k)  # sample k times according to https://arxiv.org/pdf/1705.04304.pdf, size (k, batch_size, extended_vsize)
      k_argmax = tf.argmax(one_hot_k_samples, axis=2, output_type=tf.int32) # (k, batch_size)
      k_sample = tf.transpose(k_argmax) # shape (batch_size, k)
      greedy_search_prob, greedy_search_sample = tf.nn.top_k(final_dist, k=_hps.k) # (batch_size, k)
      greedy_search_samples.append(greedy_search_sample)
      samples.append(k_sample)
      if FLAGS.use_discounted_rewards:
        _sampling_rewards = []
        _greedy_rewards = []
        for _ in range(_hps.k):
          rl_fscore = tf.reshape(rouge_l_fscore(tf.transpose(tf.stack(samples)[:, :, _]), target_batch),
                                 [-1, 1])  # shape (batch_size, 1)
          _sampling_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
          rl_fscore = tf.reshape(rouge_l_fscore(tf.transpose(tf.stack(greedy_search_samples)[:, :, _]), target_batch),
                                 [-1, 1])  # shape (batch_size, 1)
          _greedy_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
        sampling_rewards.append(tf.squeeze(tf.stack(_sampling_rewards, axis=1), axis=-1)) # (batch_size, k)
        greedy_rewards.append(tf.squeeze(tf.stack(_greedy_rewards, axis=1), axis=-1))  # (batch_size, k)

    if FLAGS.use_discounted_rewards:
      sampling_rewards = tf.stack(sampling_rewards)
      greedy_rewards = tf.stack(greedy_rewards)
    else:
      _sampling_rewards = []
      _greedy_rewards = []
      for _ in range(_hps.k):
        rl_fscore = rouge_l_fscore(tf.transpose(tf.stack(samples)[:, :, _]), target_batch) # shape (batch_size, 1)
        _sampling_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
        rl_fscore = rouge_l_fscore(tf.transpose(tf.stack(greedy_search_samples)[:, :, _]), target_batch)  # shape (batch_size, 1)
        _greedy_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
      sampling_rewards = tf.squeeze(tf.stack(_sampling_rewards, axis=1), axis=-1) # (batch_size, k)
      greedy_rewards = tf.squeeze(tf.stack(_greedy_rewards, axis=1), axis=-1) # (batch_size, k)
    # If using coverage, reshape it
    if coverage is not None:
      coverage = array_ops.reshape(coverage, [batch_size, -1])

  return (outputs, state, attn_dists, p_gens, coverage, vocab_scores,
          final_dists, samples, greedy_search_samples, temporal_e,
          sampling_rewards, greedy_rewards)
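The helper `_calc_final_dist` is called above but not shown. A minimal sketch of the pointer-generator mixture such a helper typically computes (See et al., 2017); the function name and signature below are assumptions, with `batch_size`, `vsize`, and `max_art_oovs` taken to be Python ints:

import tensorflow as tf

def pointer_final_dist(p_gen, vocab_dist, attn_dist,
                       enc_batch_extend_vocab, batch_size, vsize,
                       max_art_oovs):
    # final = p_gen * P_vocab + (1 - p_gen) * P_copy, over the vocabulary
    # extended with this batch's in-article OOV ids.
    extended_vsize = vsize + max_art_oovs
    vocab_dist = p_gen * vocab_dist                # (batch_size, vsize)
    attn_dist = (1.0 - p_gen) * attn_dist          # (batch_size, max_enc_steps)
    # Pad P_vocab with zero probability for the in-article OOV ids.
    extra_zeros = tf.zeros((batch_size, max_art_oovs))
    vocab_dist_extended = tf.concat([vocab_dist, extra_zeros], axis=1)
    # Scatter the copy mass from encoder positions onto extended-vocab ids.
    attn_len = tf.shape(enc_batch_extend_vocab)[1]
    batch_nums = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, attn_len])
    indices = tf.stack((batch_nums, enc_batch_extend_vocab), axis=2)
    attn_dist_projected = tf.scatter_nd(indices, attn_dist,
                                        [batch_size, extended_vsize])
    return vocab_dist_extended + attn_dist_projected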
Example 4
  def _add_shared_loss_op(self):
    # Calculate the loss
    with tf.variable_scope('shared_loss'):
      # Calculate the loss per step
      # This is fiddly; we use tf.gather_nd to pick out the probabilities of the gold target words
      #### added by [email protected]: we just calculate these to monitor pgen_loss throughout time
      loss_per_step = [] # will be list length max_dec_steps containing shape (batch_size)
      batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
      for dec_step, dist in enumerate(self.final_dists):
        targets = self._target_batch[:,dec_step] # The indices of the target words. shape (batch_size)
        indices = tf.stack( (batch_nums, targets), axis=1) # shape (batch_size, 2)
        gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of correct words on this step
        losses = -tf.log(gold_probs)
        loss_per_step.append(losses)
      self._pgen_loss = _mask_and_avg(loss_per_step, self._dec_padding_mask)
      self.variable_summaries('pgen_loss', self._pgen_loss)
      # Adding Q-Estimation to CE loss in Actor-Critic Model
      if self._hps.ac_training:
        # Calculating Actor-Critic loss
        # Here, we weight each token's negative log-probability by its Q-estimate
        loss_per_step = [] # will be list length k each containing a list of shape <=max_dec_steps which each has the shape (batch_size)
        q_loss_per_step = [] # will be list length k each containing a list of shape <=max_dec_steps which each has the shape (batch_size)
        batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
        unstacked_q = tf.unstack(self._q_estimates, axis=1) # list of k with size (batch_size, <=max_dec_steps, vsize_extended)
        for sample_id in range(self._hps.k):
          loss_per_sample = [] # length <=max_dec_steps of batch_sizes
          q_loss_per_sample = [] # length <=max_dec_steps of batch_sizes
          q_val_per_sample = tf.unstack(unstacked_q[sample_id], axis=1) # list of <=max_dec_steps (batch_size, vsize_extended)
          for dec_step, (dist, q_value) in enumerate(zip(self.final_dists, q_val_per_sample)):
            targets = tf.squeeze(self.samples[dec_step][:,sample_id]) # The indices of the sampled words. shape (batch_size)
            indices = tf.stack((batch_nums, targets), axis=1) # shape (batch_size, 2)
            gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of the sampled words on this step
            losses = -tf.log(gold_probs)
            dist_q_val = -tf.log(dist) * q_value
            q_losses = tf.gather_nd(dist_q_val, indices) # shape (batch_size). Q-weighted losses of the sampled words on this step
            loss_per_sample.append(losses)
            q_loss_per_sample.append(q_losses)
          loss_per_step.append(loss_per_sample)
          q_loss_per_step.append(q_loss_per_sample)
        with tf.variable_scope('reinforce_loss'):
          #### this is the actual loss
          self._rl_avg_logprobs = tf.reduce_mean([_mask_and_avg(loss_per_sample, self._dec_padding_mask) for loss_per_sample in loss_per_step])
          self._rl_loss = tf.reduce_mean([_mask_and_avg(q_loss_per_sample, self._dec_padding_mask) for q_loss_per_sample in q_loss_per_step])
          # Eq. 34 in https://arxiv.org/pdf/1805.09461.pdf
          self._reinforce_shared_loss = self._eta * self._rl_loss + (tf.constant(1.,dtype=tf.float32) - self._eta) * self._pgen_loss # equation 16 in https://arxiv.org/pdf/1705.04304.pdf
          #### the following is only for monitoring purposes
          self.variable_summaries('reinforce_avg_logprobs', self._rl_avg_logprobs)
          self.variable_summaries('reinforce_loss', self._rl_loss)
          self.variable_summaries('reinforce_shared_loss', self._reinforce_shared_loss)

      # Adding Self-Critic Reward to CE loss in Policy-Gradient Model
      if self._hps.rl_training:
        #### Calculating the reinforce loss according to Eq. 15 in https://arxiv.org/pdf/1705.04304.pdf
        loss_per_step = [] # will be list length max_dec_steps*k containing shape (batch_size)
        rl_loss_per_step = [] # will be list length max_dec_steps*k containing shape (batch_size)
        batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
        self._sampled_rouges = []
        self._greedy_rouges = []
        self._reward_diff = []
        for _ in range(self._hps.k):
          self._sampled_rouges.append(rouge_l_fscore(self.sampled_sentences[:, _, :], self._target_batch))
          self._greedy_rouges.append(rouge_l_fscore(self.greedy_search_sentences[:, _, :], self._target_batch))
          self._reward_diff.append(self._sampled_rouges[_] - self._greedy_rouges[_])
        for dec_step, dist in enumerate(self.final_dists):
          _targets = self.samples[dec_step] # The indices of the sampled words. shape (batch_size, k)
          for _k, targets in enumerate(tf.unstack(_targets,axis=1)): # list of k samples of size (batch_size)
            indices = tf.stack( (batch_nums, targets), axis=1) # shape (batch_size, 2)
            gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of the sampled words on this step
            losses = -tf.log(gold_probs)
            loss_per_step.append(losses)
            # Equation 15 in https://arxiv.org/pdf/1705.04304.pdf
            # Equal reward for all tokens
            rl_losses = -tf.log(gold_probs) * self._reward_diff[_k]
            #rl_losses = -tf.log(gold_probs) * (self._sampled_sentence_r_values[_k]-self._greedy_sentence_r_values[_k])
            rl_loss_per_step.append(rl_losses)

        # new size: (k, max_dec_steps, batch_size)
        rl_loss_per_step = tf.unstack(
          tf.transpose(tf.reshape(rl_loss_per_step, [-1, self._hps.k, self._hps.batch_size]),perm=[1,0,2]))
        loss_per_step = tf.unstack(
          tf.transpose(tf.reshape(loss_per_step, [-1, self._hps.k, self._hps.batch_size]), perm=[1, 0, 2]))
        with tf.variable_scope('reinforce_loss'):
          self._rl_avg_logprobs = []
          self._rl_loss = []

          for _k in range(self._hps.k):
            self._rl_avg_logprobs.append(_mask_and_avg(tf.unstack(loss_per_step[_k]), self._dec_padding_mask))
            self._rl_loss.append(_mask_and_avg(tf.unstack(tf.reshape(rl_loss_per_step[_k], [self._hps.max_dec_steps, self._hps.batch_size])), self._dec_padding_mask))

          self._rl_avg_logprobs = tf.reduce_mean(self._rl_avg_logprobs)
          self._rl_loss = tf.reduce_mean(self._rl_loss)
          # We multiply the ROUGE difference of sampling vs greedy sentence to the loss of all tokens in the sequence
          # Eq. 16 in https://arxiv.org/pdf/1705.04304.pdf and Eq. 34 in https://arxiv.org/pdf/1805.09461.pdf
          self._reinforce_shared_loss = self._eta * self._rl_loss + (tf.constant(1.,dtype=tf.float32) - self._eta) * self._pgen_loss
          #### the following is only for monitoring purposes
          self.variable_summaries('reinforce_avg_logprobs', self._rl_avg_logprobs)
          self.variable_summaries('reinforce_loss', self._rl_loss)
          self.variable_summaries('reinforce_sampled_r_value', tf.reduce_mean(self._sampled_rouges))
          self.variable_summaries('reinforce_greedy_r_value', tf.reduce_mean(self._greedy_rouges))
          self.variable_summaries('reinforce_r_diff', tf.reduce_mean(self._reward_diff))
          self.variable_summaries('reinforce_shared_loss', self._reinforce_shared_loss)

      # Calculate coverage loss from the attention distributions
      if self._hps.coverage:
        with tf.variable_scope('coverage_loss'):
          self._coverage_loss = _coverage_loss(self.attn_dists, self._dec_padding_mask)
          self.variable_summaries('coverage_loss', self._coverage_loss)
        if self._hps.rl_training or self._hps.ac_training:
          with tf.variable_scope('reinforce_loss'):
            self._reinforce_cov_total_loss = self._reinforce_shared_loss + self._hps.cov_loss_wt * self._coverage_loss
            self.variable_summaries('reinforce_coverage_loss', self._reinforce_cov_total_loss)
        if self._hps.pointer_gen:
          self._pointer_cov_total_loss = self._pgen_loss + self._hps.cov_loss_wt * self._coverage_loss
          self.variable_summaries('pointer_coverage_loss', self._pointer_cov_total_loss)
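`_mask_and_avg` is used throughout this example but not shown. A minimal sketch of the standard pointer-generator definition (assumed here): mask each step's per-example loss, normalize by the true decoded length, and average over the batch.

import tensorflow as tf

def _mask_and_avg(values, padding_mask):
    # values: list of max_dec_steps tensors, each of shape (batch_size).
    # padding_mask: (batch_size, max_dec_steps) with 1s on real tokens, 0s on padding.
    dec_lens = tf.reduce_sum(padding_mask, axis=1)   # per-example decoded lengths
    values_per_step = [v * padding_mask[:, step]
                       for step, v in enumerate(values)]
    values_per_ex = sum(values_per_step) / dec_lens  # length-normalized per example
    return tf.reduce_mean(values_per_ex)             # average over the batch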