Example #1
  def _body(step, finished, inputs, lengths, log_probs, state):
    logits, state = symbols_to_logits_fn(inputs, step, state)
    probs = tf.nn.log_softmax(tf.cast(logits, tf.float32))  # per-step log-probabilities
    if min_decode_length > 0:
      probs = tf.cond(
          step < min_decode_length,
          true_fn=lambda: beam_search.penalize_token(probs, end_id),
          false_fn=lambda: probs)

    # Sample best prediction.
    sample_ids = tf.argmax(probs, axis=-1, output_type=inputs.dtype)
    sample_probs = tf.reduce_max(probs, axis=-1)

    # Don't update finished batches.
    masked_lengths_inc = 1 - tf.cast(finished, lengths.dtype)
    masked_probs = sample_probs * (1.0 - tf.cast(finished, sample_probs.dtype))

    step = step + 1
    inputs = tf.concat([inputs, tf.expand_dims(sample_ids, 1)], -1)
    lengths += masked_lengths_inc
    log_probs += masked_probs
    finished = tf.logical_or(finished, tf.equal(sample_ids, end_id))
    if decode_length is not None:
      finished = tf.logical_or(finished, step >= decode_length)

    return step, finished, inputs, lengths, log_probs, state
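
The body above is written to be iterated by tf.while_loop. A minimal sketch of a driver is shown below; start_ids, batch_size, and the initial decoder state are hypothetical names not defined in the excerpt, and the inputs invariant must leave the time axis unknown because each iteration concatenates one more token:

    import tensorflow as tf

    # Hypothetical driver for the _body above. start_ids, batch_size, state
    # and symbols_to_logits_fn are assumptions, not part of the excerpt.
    step = tf.constant(0, dtype=tf.int32)
    finished = tf.zeros([batch_size], dtype=tf.bool)
    inputs = tf.expand_dims(start_ids, 1)        # [batch, 1], grows each step
    lengths = tf.zeros([batch_size], dtype=tf.int32)
    log_probs = tf.zeros([batch_size], dtype=tf.float32)

    def _cond(step, finished, *args):
        # Stop once every sequence in the batch has emitted end_id.
        return tf.logical_not(tf.reduce_all(finished))

    _, _, sequences, lengths, log_probs, _ = tf.while_loop(
        _cond,
        _body,
        loop_vars=(step, finished, inputs, lengths, log_probs, state),
        shape_invariants=(
            step.shape,
            finished.shape,
            tf.TensorShape([None, None]),  # time axis unknown: inputs grows
            lengths.shape,
            log_probs.shape,
            # Assumes state is a (possibly nested) structure of tensors.
            tf.nest.map_structure(lambda t: tf.TensorShape(None), state)))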
Example #2
    def _body(step, finished, inputs, outputs, lengths, cum_log_probs, state):
        # Run next step.
        logits, state = symbols_to_logits_fn(inputs, step, state)
        log_probs = tf.nn.log_softmax(tf.cast(logits, tf.float32))
        if min_decode_length > 0:
            log_probs = tf.cond(
                step < min_decode_length,
                true_fn=lambda: beam_search.penalize_token(log_probs, end_id),
                false_fn=lambda: log_probs)

        # Sample best prediction.
        sampled_log_probs, sampled_ids = tf.nn.top_k(log_probs, k=1)
        sampled_log_probs = tf.squeeze(sampled_log_probs, axis=1)
        sampled_ids = tf.squeeze(sampled_ids, axis=1)
        outputs = outputs.write(step, sampled_ids)

        # Don't update finished batches.
        lengths += 1 - tf.cast(finished, lengths.dtype)
        cum_log_probs += sampled_log_probs * (
            1.0 - tf.cast(finished, sampled_log_probs.dtype))
        finished = tf.logical_or(finished, tf.equal(sampled_ids, end_id))
        if last_step_as_input:
            next_inputs = sampled_ids
        else:
            next_inputs = tf.concat(
                [inputs, tf.expand_dims(sampled_ids, 1)], axis=1)
        return step + 1, finished, next_inputs, outputs, lengths, cum_log_probs, state
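
Both loop bodies rely on beam_search.penalize_token to keep end_id from being selected before min_decode_length, but the excerpts do not include the helper itself. A minimal sketch of a compatible implementation, assuming it works by adding a large negative bias to one vocabulary entry (the -1e7 default is an assumed value, not confirmed by the source):

    import tensorflow as tf

    def penalize_token(log_probs, token_id, penalty=-1e7):
        # Sketch of a compatible helper: bias one vocabulary entry so far
        # down that argmax/top_k/sampling will not select it.
        # The -1e7 default is an assumption, not taken from the source.
        vocab_size = tf.shape(log_probs)[-1]
        bias = tf.one_hot(token_id, vocab_size,
                          on_value=float(penalty), off_value=0.0,
                          dtype=log_probs.dtype)
        return log_probs + bias  # broadcasts over the batch dimension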
Example #3
  def testPenalizeToken(self):
    log_probs = tf.zeros([4, 6])
    token_id = 1
    log_probs = beam_search.penalize_token(log_probs, token_id)
    log_probs = self.evaluate(log_probs)
    self.assertTrue(np.all(log_probs[:, token_id] < 0))
    # Drop the penalized column; the remaining entries should be unchanged.
    non_penalized = np.delete(log_probs, token_id, axis=1)
    self.assertEqual(np.sum(non_penalized), 0)
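
For concreteness, the state the test expects, assuming the hypothetical -1e7 penalty from the sketch above: starting from an all-zero [4, 6] matrix, only the penalized column changes.

    import numpy as np

    # Expected post-penalty values, assuming a -1e7 bias (an assumption).
    log_probs = np.zeros((4, 6), dtype=np.float32)
    log_probs[:, 1] += -1e7                              # penalized column
    assert np.all(log_probs[:, 1] < 0)                   # now negative
    assert np.sum(np.delete(log_probs, 1, axis=1)) == 0  # others untouched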
Example #4
    def _body(step, finished, inputs, outputs, lengths, cum_log_probs, state):
        # Run next step.
        logits, state = symbols_to_logits_fn(inputs, step, state)
        logits = tf.cast(logits, tf.float32)
        if sample_temperature != 1:
            logits /= tf.cast(sample_temperature, logits.dtype)
        log_probs = tf.nn.log_softmax(logits)
        if min_decode_length > 0:
            log_probs = tf.cond(
                step < min_decode_length,
                true_fn=lambda: beam_search.penalize_token(log_probs, end_id),
                false_fn=lambda: log_probs)

        if sample_from == 1:  # Sample best prediction.
            sampled_ids = tf.argmax(log_probs,
                                    axis=-1,
                                    output_type=inputs.dtype)
        elif sample_from == 0:  # Sample from the full output distribution.
            distribution = tf.distributions.Categorical(
                probs=tf.exp(log_probs), dtype=inputs.dtype)
            sampled_ids = distribution.sample()
        else:  # Sample from the top K.
            topk_log_probs, topk_ids = tf.nn.top_k(log_probs, k=sample_from)
            topk_ids = tf.cast(topk_ids, inputs.dtype)
            distribution = tf.distributions.Categorical(logits=topk_log_probs,
                                                        dtype=inputs.dtype)
            topk_sampled_ids = distribution.sample()
            sampled_ids = tf.gather_nd(
                topk_ids, tf.stack([batch_ids, topk_sampled_ids], axis=-1))

        sampled_log_probs = tf.gather_nd(
            log_probs, tf.stack([batch_ids, sampled_ids], axis=-1))
        outputs = outputs.write(step, sampled_ids)

        # Don't update finished batches.
        lengths += 1 - tf.cast(finished, lengths.dtype)
        cum_log_probs += sampled_log_probs * (
            1.0 - tf.cast(finished, sampled_log_probs.dtype))
        finished = tf.logical_or(finished, tf.equal(sampled_ids, end_id))
        if last_step_as_input:
            next_inputs = sampled_ids
        else:
            next_inputs = tf.concat(
                [inputs, tf.expand_dims(sampled_ids, 1)], axis=1)
        return step + 1, finished, next_inputs, outputs, lengths, cum_log_probs, state
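
The top-K branch above samples a position within the top-K list and then maps it back to a vocabulary id with tf.gather_nd. A self-contained sketch of that trick with toy logits (shapes and values are illustrative only), using tf.random.categorical in place of the deprecated tf.distributions.Categorical:

    import tensorflow as tf

    # Toy logits: batch of 2, vocabulary of 4. Values are illustrative.
    logits = tf.constant([[1.0, 3.0, 0.5, 2.0],
                          [0.1, 0.2, 4.0, 1.0]])
    k = 2
    log_probs = tf.nn.log_softmax(logits)
    topk_log_probs, topk_ids = tf.nn.top_k(log_probs, k=k)  # [2, k] each
    # Sample a position in [0, k) per batch entry, weighted by top-k mass.
    positions = tf.random.categorical(topk_log_probs, num_samples=1)
    positions = tf.squeeze(tf.cast(positions, tf.int32), axis=1)
    # Map each sampled position back to its vocabulary id.
    batch_ids = tf.range(tf.shape(logits)[0])
    sampled_ids = tf.gather_nd(
        topk_ids, tf.stack([batch_ids, positions], axis=-1))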