Example #1
def _apply_gradients(grads_and_vars, global_step=None):
    # Average gradients across replicas when the optimizer exposes an
    # allreduce hook.
    if hasattr(optimizer, "allreduce_gradients"):
        grads_and_vars = optimizer.allreduce_gradients(grads_and_vars)
    kwargs = {}
    # Only forward decay_var_list when the optimizer's apply_gradients
    # signature declares it, excluding bias variables from weight decay.
    if "decay_var_list" in misc.function_args(optimizer.apply_gradients):
        kwargs["decay_var_list"] = [
            var for _, var in grads_and_vars if not _is_bias(var)
        ]
    return optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, **kwargs)
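
The `misc.function_args` check above is a feature-detection idiom: the keyword argument is only forwarded when the optimizer's apply_gradients actually declares it. A minimal, standard-library-only sketch of the same idiom (the apply_gradients stub below is hypothetical, not the TensorFlow API):

import inspect

def function_args(fn):
    # Stand-in for misc.function_args: the parameter names of ``fn``.
    return list(inspect.signature(fn).parameters)

def apply_gradients(grads_and_vars, global_step=None, decay_var_list=None):
    # Hypothetical optimizer method with an optional decay_var_list.
    return len(grads_and_vars), decay_var_list

kwargs = {}
if "decay_var_list" in function_args(apply_gradients):
    kwargs["decay_var_list"] = ["kernel_var"]
print(apply_gradients([], **kwargs))  # (0, ['kernel_var'])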
Example #2
def dynamic_decode(symbols_to_logits_fn,
                   start_ids,
                   end_id=constants.END_OF_SENTENCE_ID,
                   initial_state=None,
                   decoding_strategy=None,
                   sampler=None,
                   maximum_iterations=None,
                   minimum_iterations=0,
                   attention_history=False):
  """Dynamic decoding.

  Args:
    symbols_to_logits_fn: A callable taking ``(symbols, step, state)`` and
      returning ``(logits, state, attention)`` (``attention`` is optional).
    start_ids: Initial input IDs of shape :math:`[B]`.
    end_id: ID of the end of sequence token.
    initial_state: Initial decoder state.
    decoding_strategy: A :class:`opennmt.utils.decoding.DecodingStrategy`
      instance that defines the decoding logic. Defaults to a greedy search.
    sampler: A :class:`opennmt.utils.decoding.Sampler` instance that samples
      predictions from the model output. Defaults to argmax sampling.
    maximum_iterations: The maximum number of iterations to decode for.
    minimum_iterations: The minimum number of iterations to decode for.
    attention_history: Gather attention history during the decoding.

  Returns:
    ids: The predicted ids of shape :math:`[B, H, T]`.
    lengths: The produced sequence lengths of shape :math:`[B, H]`.
    log_probs: The cumulative log probabilities of shape :math:`[B, H]`.
    attention_history: The attention history of shape :math:`[B, H, T_t, T_s]`.
    state: The final decoding state.
  """
  if "maximum_iterations" not in misc.function_args(tf.while_loop):
    raise NotImplementedError("Unified decoding does not support TensorFlow 1.4. "
                              "Please update your TensorFlow installation or open "
                              "an issue for assistance.")
  if decoding_strategy is None:
    decoding_strategy = GreedySearch()
  if sampler is None:
    sampler = BestSampler()

  def _cond(step, finished, state, inputs, outputs, attention, cum_log_probs, extra_vars):  # pylint: disable=unused-argument
    return tf.reduce_any(tf.logical_not(finished))

  def _body(step, finished, state, inputs, outputs, attention, cum_log_probs, extra_vars):
    # Get log probs from the model.
    result = symbols_to_logits_fn(inputs, step, state)
    logits, state = result[0], result[1]
    attn = result[2] if len(result) > 2 else None
    logits = tf.cast(logits, tf.float32)

    # Penalize or force EOS.
    batch_size, vocab_size = misc.shape_list(logits)
    eos_max_prob = tf.one_hot(
        tf.fill([batch_size], end_id),
        vocab_size,
        on_value=logits.dtype.max,
        off_value=logits.dtype.min)
    logits = tf.cond(
        step < minimum_iterations,
        true_fn=lambda: _penalize_token(logits, end_id),
        false_fn=lambda: tf.where(
            tf.tile(tf.expand_dims(finished, 1), [1, vocab_size]),
            x=eos_max_prob,
            y=logits))
    log_probs = tf.nn.log_softmax(logits)

    # Run one decoding strategy step.
    output, next_cum_log_probs, finished, state, extra_vars = decoding_strategy.step(
        step,
        sampler,
        log_probs,
        cum_log_probs,
        finished,
        state,
        extra_vars)

    # Update loop vars.
    if attention_history:
      if attn is None:
        raise ValueError("attention_history is set but the model did not return attention")
      attention = attention.write(step, tf.cast(attn, tf.float32))
    outputs = outputs.write(step, output)
    cum_log_probs = tf.where(finished, x=cum_log_probs, y=next_cum_log_probs)
    finished = tf.logical_or(finished, tf.equal(output, end_id))
    return step + 1, finished, state, output, outputs, attention, cum_log_probs, extra_vars

  batch_size = tf.shape(start_ids)[0]
  ids_dtype = start_ids.dtype
  start_ids = tf.cast(start_ids, tf.int32)
  start_ids, finished, initial_log_probs, extra_vars = decoding_strategy.initialize(
      batch_size, start_ids)
  step = tf.constant(0, dtype=tf.int32)
  outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  attention = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  _, _, state, _, outputs, attention, log_probs, extra_vars = tf.while_loop(
      _cond,
      _body,
      loop_vars=(
          step,
          finished,
          initial_state,
          start_ids,
          outputs,
          attention,
          initial_log_probs,
          extra_vars),
      shape_invariants=(
          step.shape,
          finished.shape,
          compat.nest.map_structure(_get_shape_invariants, initial_state),
          start_ids.shape,
          tf.TensorShape(None),
          tf.TensorShape(None),
          initial_log_probs.shape,
          compat.nest.map_structure(_get_shape_invariants, extra_vars)),
      parallel_iterations=1,
      back_prop=False,
      maximum_iterations=maximum_iterations)

  ids, attention, lengths = decoding_strategy.finalize(
      outputs,
      end_id,
      extra_vars,
      attention=attention if attention_history else None)
  if attention is not None:
    attention = attention[:, :, 1:]  # Ignore attention for <s>.
  log_probs = tf.reshape(log_probs, [batch_size, decoding_strategy.num_hypotheses])
  ids = tf.cast(ids, ids_dtype)
  return ids, lengths, log_probs, attention, state
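
The loop interacts with the model only through the ``symbols_to_logits_fn`` contract: given the previously emitted symbols, the step index, and the carried state, it returns logits and an updated state. The framework-free sketch below replays that contract with a hypothetical 4-token model to show how greedy (argmax) decoding and EOS termination fit together; it illustrates the contract and is not OpenNMT code:

import numpy as np

END_ID = 3  # hypothetical end-of-sentence id

def toy_symbols_to_logits_fn(symbols, step, state):
    # Hypothetical model: emit token step + 1 for two steps, then EOS.
    logits = np.full((symbols.shape[0], 4), -1e9, dtype=np.float32)
    logits[:, END_ID if step >= 2 else step + 1] = 0.0
    return logits, state

def greedy_decode(start_ids, maximum_iterations=10):
    inputs, state = start_ids, None
    finished = np.zeros(start_ids.shape[0], dtype=bool)
    outputs = []
    for step in range(maximum_iterations):
        logits, state = toy_symbols_to_logits_fn(inputs, step, state)
        output = logits.argmax(axis=-1)  # argmax sampling, as in BestSampler
        outputs.append(output)
        finished |= output == END_ID
        if finished.all():               # same stop test as _cond above
            break
        inputs = output                  # feed the prediction back in
    return np.stack(outputs, axis=1)

print(greedy_decode(np.zeros(2, dtype=np.int32)))  # [[1 2 3] [1 2 3]]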
Example #3
def dynamic_decode(symbols_to_logits_fn_trans,
                   symbols_to_logits_fn_ae,
                   trans_model_name,
                   ae_model_name,
                   start_ids,
                   end_id=constants.END_OF_SENTENCE_ID,
                   initial_state_trans=None,
                   initial_state_ae=None,
                   decoding_strategy=None,
                   sampler=None,
                   maximum_iterations=None,
                   minimum_iterations=0,
                   attention_history=False,
                   attention_size=None,
                   low_prob=None):
    """Dynamic decoding.

  Args:
    symbols_to_logits_fn: A callable taking ``(symbols, step, state)`` and
      returning ``(logits, state, attention)`` (``attention`` is optional).
    start_ids: Initial input IDs of shape :math:`[B]`.
    end_id: ID of the end of sequence token.
    initial_state: Initial decoder state.
    decoding_strategy: A :class:`opennmt.utils.decoding.DecodingStrategy`
      instance that define the decoding logic. Defaults to a greedy search.
    sampler: A :class:`opennmt.utils.decoding.Sampler` instance that samples
      predictions from the model output. Defaults to an argmax sampling.
    maximum_iterations: The maximum number of iterations to decode for.
    minimum_iterations: The minimum number of iterations to decode for.
    attention_history: Gather attention history during the decoding.
    attention_size: If known, the size of the attention vectors (i.e. the
      maximum source length).

  Returns:
    ids: The predicted ids of shape :math:`[B, H, T]`.
    lengths: The produced sequences length of shape :math:`[B, H]`.
    log_probs: The cumulated log probabilities of shape :math:`[B, H]`.
    attention_history: The attention history of shape :math:`[B, H, T_t, T_s]`.
    state: The final decoding state.
  """
    if "maximum_iterations" not in misc.function_args(tf.while_loop):
        raise NotImplementedError(
            "Unified decoding does not support TensorFlow 1.4. "
            "Please update your TensorFlow installation or open "
            "an issue for assistance.")
    if decoding_strategy is None:
        decoding_strategy = GreedySearch()
    if sampler is None:
        sampler = BestSampler()

    def _cond(step, finished, state_trans, state_ae, inputs, outputs,
              attention, cum_log_probs, extra_vars):  # pylint: disable=unused-argument
        return tf.reduce_any(tf.logical_not(finished))

    def _body(step, finished, state_trans, state_ae, inputs, outputs,
              attention, cum_log_probs, extra_vars):
        # Get log probs from the model.
        with tf.variable_scope(trans_model_name, reuse=tf.AUTO_REUSE):
            with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
                result_trans = symbols_to_logits_fn_trans(
                    inputs, step, state_trans)
        with tf.variable_scope(ae_model_name, reuse=tf.AUTO_REUSE):
            with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
                result_ae = symbols_to_logits_fn_ae(inputs, step, state_ae)
        logits_trans, state_trans = result_trans[0], result_trans[1]
        logits_ae, state_ae = result_ae[0], result_ae[1]
        logits_trans = tf.cast(logits_trans, tf.float32)
        logits_ae = tf.cast(logits_ae, tf.float32)
        # Fuse the two decoders with a weighted sum of their logits.
        logits = logits_trans + 0.5 * logits_ae
        if low_prob is not None:
            # Lower every token's probability by shrinking the logits.
            logits = logits - tf.abs(logits * low_prob)

        # Penalize or force EOS.
        batch_size, vocab_size = misc.shape_list(logits_trans)
        eos_max_prob = tf.one_hot(tf.fill([batch_size], end_id),
                                  vocab_size,
                                  on_value=logits_trans.dtype.max,
                                  off_value=logits_trans.dtype.min)
        logits = tf.where(finished, x=eos_max_prob, y=logits)
        log_probs = tf.nn.log_softmax(logits)
        # A disabled variant penalized EOS before minimum_iterations and fused
        # the per-model log-probs (e.g. log_probs_trans + 0.2 * log_probs_ae)
        # instead of the raw logits.

        # Run one decoding strategy step.
        print("============================================================")
        print(type(decoding_strategy))
        output, next_cum_log_probs, finished, state_trans, state_ae, extra_vars = decoding_strategy.step(
            step,
            sampler,
            log_probs,
            cum_log_probs,
            finished,
            state_trans,
            state_ae,
            extra_vars,
            attention=None)

        outputs = outputs.write(step, output)
        cum_log_probs = tf.where(finished,
                                 x=cum_log_probs,
                                 y=next_cum_log_probs)
        finished = tf.logical_or(finished, tf.equal(output, end_id))
        return step + 1, finished, state_trans, state_ae, output, outputs, attention, cum_log_probs, extra_vars

    batch_size = tf.shape(start_ids)[0]
    ids_dtype = start_ids.dtype
    start_ids = tf.cast(start_ids, tf.int32)
    start_ids, finished, initial_log_probs, extra_vars = decoding_strategy.initialize(
        batch_size, start_ids, attention_size=attention_size)
    step = tf.constant(0, dtype=tf.int32)
    outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    attention = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    _, _, state_trans, state_ae, _, outputs, attention, log_probs, extra_vars = tf.while_loop(
        _cond,
        _body,
        loop_vars=(step, finished, initial_state_trans, initial_state_ae,
                   start_ids, outputs, attention, initial_log_probs,
                   extra_vars),
        shape_invariants=(step.shape, finished.shape,
                          compat.nest.map_structure(_get_shape_invariants,
                                                    initial_state_trans),
                          compat.nest.map_structure(_get_shape_invariants,
                                                    initial_state_ae),
                          start_ids.shape, tf.TensorShape(None),
                          tf.TensorShape(None), initial_log_probs.shape,
                          compat.nest.map_structure(_get_shape_invariants,
                                                    extra_vars)),
        parallel_iterations=1,
        back_prop=False,
        maximum_iterations=maximum_iterations)

    ids, attention, lengths = decoding_strategy.finalize(
        outputs,
        end_id,
        extra_vars,
        attention=attention if attention_history else None)
    if attention is not None:
        attention = attention[:, :, :-1]  # Ignore attention for </s>.
    log_probs = tf.reshape(log_probs,
                           [batch_size, decoding_strategy.num_hypotheses])
    ids = tf.cast(ids, ids_dtype)
    return ids, lengths
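
What distinguishes Example #3 from the stock decoder is its scoring step: both decoders run in lockstep and their logits are fused with fixed weights (1.0 and 0.5 in the code above) before sampling, plus an optional ``low_prob`` down-scaling. A NumPy sketch of just that fusion arithmetic (the vocabulary and values are made up for illustration):

import numpy as np

def fuse_logits(logits_trans, logits_ae, ae_weight=0.5, low_prob=None):
    # Weighted sum of translation-model and autoencoder logits.
    logits = logits_trans + ae_weight * logits_ae
    if low_prob is not None:
        # Shrink each logit by a fraction of its magnitude: positive logits
        # move toward 0, negative logits become more negative.
        logits = logits - np.abs(logits * low_prob)
    return logits

logits_trans = np.array([[2.0, 0.5, -1.0]])
logits_ae = np.array([[1.0, 1.0, 0.0]])
fused = fuse_logits(logits_trans, logits_ae, low_prob=0.1)
probs = np.exp(fused) / np.exp(fused).sum(axis=-1, keepdims=True)
print(fused)  # [[ 2.25  0.9  -1.1 ]]
print(probs)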