Example #1
def SequenceAppendToken(x, x_paddings, token, extend=False):
    """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension. This must be
      True if any sequence in `x` already has length `x_len_max`, or else an
      invalid sequence will be emitted.

  Returns:
    A tuple of:
      - The new sequence, a Tensor of shape [batch_size, x_len_max]
        ([batch_size, x_len_max + 1] if `extend` is True).
      - The new paddings, a Tensor of the same shape.
  """
    batch_size = py_utils.GetShape(x)[0]
    x_len = tf.to_int32(tf.round(tf.reduce_sum(1 - x_paddings, 1)))
    if extend:
        x = tf.pad(x, [[0, 0], [0, 1]])
    # Mask all invalid entries of `x` to 0.
    x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
    # Append the <token> based on `x_len`.
    x += tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1),
                       tf.cast(tf.fill([batch_size], token), x.dtype),
                       py_utils.GetShape(x))
    x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                      py_utils.GetShape(x)[1],
                                      x_paddings.dtype)
    return x, x_paddings
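For reference, here is a minimal self-contained TF2 sketch of the same mask-and-scatter append, with made-up tensors and `tf.shape` standing in for `py_utils.GetShape`:

import tensorflow as tf

# Two sequences of max length 4; the second has one padded slot.
x = tf.constant([[5, 6, 7, 8],
                 [9, 3, 0, 0]], dtype=tf.int32)
x_paddings = tf.constant([[0., 0., 0., 0.],
                          [0., 0., 1., 1.]])
token = 2  # e.g. <eos>

batch_size = tf.shape(x)[0]
x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
# The first row is already at x_len_max, so extend=True is required.
x = tf.pad(x, [[0, 0], [0, 1]])
x *= tf.sequence_mask(x_len, tf.shape(x)[1], x.dtype)
x += tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1),
                   tf.fill([batch_size], token), tf.shape(x))
print(x.numpy())  # [[5 6 7 8 2], [9 3 2 0 0]]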
Example #2
 def Step(recurrent_theta, state0, inputs):
     """Computes one decoder step."""
     del inputs
     with tf.name_scope('single_sampler_step'):
         # Compute logits and states.
         bs_result, bs_state1 = pre_step_callback(
             recurrent_theta.theta,
             recurrent_theta.encoder_outputs,
             tf.expand_dims(state0.ids, 1),  # [batch, 1].
             state0.bs_state,
             num_hyps_per_beam=1)
         batch = tf.shape(bs_result.log_probs)[0]
         state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
         state1.logits = bs_result.log_probs
         # Sample ids from logits. [batch].
         state1.ids = tf.reshape(
             tf.random.stateless_multinomial(
                 state1.logits / p.temperature,
                 num_samples=1,
                 seed=tf.stack(
                     [recurrent_theta.random_seed, state0.timestep]),
                 output_dtype=state0.ids.dtype,
                 name='sample_next_id'), [batch])
         if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
             state1.ids = tf.where(
                 tf.logical_and(bs_result.is_last_chunk,
                                tf.equal(state1.ids, p.target_eoc_id)),
                 tf.fill(tf.shape(state1.ids), p.target_eos_id),
                 state1.ids)
         state1.bs_state = post_step_callback(
             recurrent_theta.theta, recurrent_theta.encoder_outputs,
             state1.ids, bs_state1)
     return state1, py_utils.NestedMap()
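The sampling call above is the interesting bit: seeding the stateless op with `[random_seed, timestep]` makes each step's draw deterministic and replayable. A toy sketch of just that call with made-up logits (`tf.random.stateless_categorical` is the current name for the deprecated `stateless_multinomial` used here):

import tensorflow as tf

logits = tf.math.log([[0.1, 0.2, 0.3, 0.2, 0.2]] * 3)  # [batch=3, vocab=5]
temperature = 0.7
seed = tf.stack([tf.constant(1234), tf.constant(7)])  # [random_seed, timestep]

ids = tf.reshape(
    tf.random.stateless_categorical(
        logits / temperature, num_samples=1, seed=seed, dtype=tf.int32), [3])
# The same (random_seed, timestep) pair always yields the same ids.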
Example #3
def _KeepTopP(sorted_log_probs, p):
    """Keeps the top-p probability mass of `sorted_log_probs`.

  For each row, elements that are not included in the first `p` probability mass
  are set to `LARGE_NEGATIVE_NUMBER`. The first element is always kept as-is.

  Args:
    sorted_log_probs: A float tensor of shape [batch, k] that represents
      log-probabilities sorted in descending order. The probabilities do not
      need to sum to 1.
    p: A float tensor of shape [batch] that represents a probability threshold
      for each batch item.

  Returns:
    A tensor like `sorted_log_probs` where elements outside the top-p
    probability mass are set to `LARGE_NEGATIVE_NUMBER`.
  """
    sorted_cum_probs = tf.math.cumsum(tf.exp(sorted_log_probs),
                                      exclusive=True,
                                      axis=-1)
    mask = tf.less(sorted_cum_probs, tf.expand_dims(p, axis=1))
    # Set mask[:, 0] = True to always keep the first element.
    batch_size = tf.shape(mask)[0]
    true = tf.ones([batch_size, 1], dtype=tf.bool)
    mask = tf.concat([true, mask[:, 1:]], axis=1)
    filtered_sorted_log_probs = tf.where(
        mask, sorted_log_probs,
        tf.fill(
            tf.shape(sorted_log_probs),
            tf.constant(LARGE_NEGATIVE_NUMBER, dtype=sorted_log_probs.dtype)))
    return filtered_sorted_log_probs
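A small usage sketch with made-up inputs, assuming `_KeepTopP` and `LARGE_NEGATIVE_NUMBER` from this module are in scope:

import tensorflow as tf

# One row, already sorted in descending order.
sorted_log_probs = tf.math.log([[0.5, 0.3, 0.15, 0.05]])
p = tf.constant([0.8])
filtered = _KeepTopP(sorted_log_probs, p)
# The exclusive cumsum of the probs is [0, 0.5, 0.8, 0.95]; only the first
# two entries fall strictly below p=0.8, so the last two are replaced by
# LARGE_NEGATIVE_NUMBER.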
Example #4
  def testBeamSearchDecodeBiased(self, bias, bias_only_if_consistent):
    dtype = tf.float32
    with self.session(use_gpu=True) as sess, self.SetEval(True):
      tf.random.set_seed(_TF_RANDOM_SEED)
      src_batch = 2
      p = self._DecoderParams(dtype=dtype)
      p.bias_only_if_consistent = bias_only_if_consistent
      p.target_seq_len = 6
      p.beam_search.num_hyps_per_beam = 2
      p.rnn_cell_dim = 32
      dec = p.Instantiate()
      encoder_outputs, _ = self._Inputs(dtype=dtype)
      encoder_outputs['targets'] = py_utils.NestedMap(
          labels=tf.constant([[1, 3, 0, 0], [3, 4, 5, 2]]),
          paddings=tf.constant([[0, 0, 1, 1], [0, 0, 0, 0]], dtype=dtype))
      encoder_outputs['targets']['weights'] = tf.fill(
          tf.shape(encoder_outputs.targets.labels), bias)
      decode = dec.BeamSearchDecodeBiased(encoder_outputs)

      # topk_decoded is None in the MT decoder; set it to a fake tensor to
      # pass sess.run(decode).
      decode = decode._replace(topk_decoded=tf.constant(0, tf.float32))

      tf.global_variables_initializer().run()
      actual_decode = sess.run(decode)

    num_hyps = src_batch * p.beam_search.num_hyps_per_beam
    self.assertTupleEqual((p.target_seq_len, num_hyps),
                          actual_decode.done_hyps.shape)
    self.assertTupleEqual((src_batch, p.beam_search.num_hyps_per_beam),
                          actual_decode.topk_hyps.shape)
    self.assertTupleEqual((num_hyps, p.target_seq_len),
                          actual_decode.topk_ids.shape)
    self.assertTupleEqual((num_hyps,), actual_decode.topk_lens.shape)
    self.assertTupleEqual((src_batch, p.beam_search.num_hyps_per_beam),
                          actual_decode.topk_scores.shape)

    if bias == 0:
      expected_topk_ids = [[2, 0, 0, 0, 0, 0], [13, 2, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
      expected_topk_lens = [1, 2, 0, 0]
      expected_topk_scores = [[-3.783162, -5.767723], [0., 0.]]
    elif bias == 1 and bias_only_if_consistent:
      expected_topk_ids = [[1, 3, 2, 0, 0, 0], [1, 3, 13, 2, 0, 0],
                           [3, 4, 5, 2, 0, 0], [0, 0, 0, 0, 0, 0]]
      expected_topk_lens = [3, 4, 4, 0]
      expected_topk_scores = [[-3.073836, -5.474799], [-0.415888, 0.]]
    elif bias == 1 and (not bias_only_if_consistent):
      expected_topk_ids = [[1, 3, 2, 0, 0, 0], [1, 3, 13, 2, 0, 0],
                           [3, 4, 5, 2, 0, 0], [3, 4, 0, 2, 0, 0]]
      expected_topk_lens = [3, 4, 4, 4]
      expected_topk_scores = [[-3.073836, -5.474799], [-0.415888, -23.295631]]

    self.assertAllEqual(expected_topk_ids, actual_decode.topk_ids)
    self.assertAllEqual(expected_topk_lens, actual_decode.topk_lens)
    self.assertAllClose(expected_topk_scores, actual_decode.topk_scores)
Example #5
 def testMassLayer(self):
     with self.session(use_gpu=False) as sess:
         batch_size = 3
         seq_len = 10
         p = self._MassParams()
         mass_layer = data_augmenter.MASS(p)
         seq_ids = tf.fill([batch_size, seq_len], 4)
         weights = tf.ones([batch_size, seq_len])
         actual_seq_len = tf.fill([batch_size], 10)
         mass_out = mass_layer.Mask(seq_ids, weights, actual_seq_len)
         (src_ids, tgt_ids, tgt_labels, tgt_weights) = sess.run([
             mass_out.src.ids, mass_out.tgt.ids, mass_out.tgt.labels,
             mass_out.tgt.weights
         ])
         self.assertAllEqual(np.sum(src_ids == 3, axis=1), [5, 5, 5])
         self.assertAllEqual(np.sum(tgt_ids == 3, axis=1), [5, 5, 5])
         self.assertAllEqual(
             tgt_labels, 4 * np.ones([batch_size, seq_len], dtype=np.int32))
         self.assertAllEqual(np.sum(tgt_weights, axis=1), [5., 5., 5.])
Example #6
        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=p.num_hyps_per_beam)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)
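When `p.top_k > 0`, the sample is drawn over the k filtered columns, so the drawn index must be mapped back to a vocab id; that is what the batched `tf.gather` above does. A toy sketch with made-up values:

import tensorflow as tf

topk_ids = tf.constant([[7, 3, 9],
                        [2, 8, 4]])  # vocab ids of each row's top-3 logits
ids = tf.constant([1, 2])            # sampled positions within the top-3

vocab_ids = tf.gather(topk_ids, ids, axis=1, batch_dims=1)
print(vocab_ids.numpy())  # [3 4]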
Example #7
    def forward(inputs, alpha):
        with tf.name_scope("entmax_loss"):
            alpha_shape = inputs.get_shape().as_list()

            alpha_shape[axis] = 1
            alpha = tf.fill(alpha_shape, alpha)
            alpha = tf.cast(alpha, dtype=inputs.dtype)

            d = inputs.get_shape().as_list()[axis]
            alpha_m1 = alpha - 1.0

            inputs = inputs * alpha_m1

            max_val = tf.math.reduce_max(inputs, axis=axis, keepdims=True)
            tau_lo = max_val - tf.ones(alpha.get_shape().as_list(),
                                       dtype=inputs.dtype)
            tau_hi = max_val - tf.math.pow(
                tf.cast((1.0 / d), dtype=inputs.dtype), alpha_m1)

            f_lo = tf.math.reduce_sum(
                _calculate_probability(tf.math.subtract(inputs, tau_lo),
                                       alpha), axis) - 1.0

            dm = tau_hi - tau_lo

            for _ in range(n_iter):
                dm /= 2
                tau_m = tau_lo + dm
                p_m = _calculate_probability(inputs - tau_m, alpha)
                f_m = tf.math.reduce_sum(p_m, axis) - 1.0

                mask = tf.expand_dims(tf.math.greater(f_m * f_lo, 0), axis)
                tau_lo = tf.where(mask, tau_m, tau_lo)

            if ensure_sum_one:
                p_m /= tf.expand_dims(tf.math.reduce_sum(p_m, axis), axis)

        def grad_fn(d_outputs):
            with tf.name_scope("entmax_grad"):
                gppr = tf.where(p_m > 0, tf.math.pow(p_m, 2.0 - alpha),
                                tf.zeros_like(p_m))
                d_inputs = d_outputs * gppr
                q = tf.math.reduce_sum(d_inputs, axis) / tf.math.reduce_sum(
                    gppr, axis)
                q = tf.expand_dims(q, axis)
                d_inputs -= q * gppr
                return d_inputs, d_inputs

        return p_m, grad_fn
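`_calculate_probability` is not shown here; per the entmax literature a plausible form is `relu(x) ** (1 / (alpha - 1))`. Under that assumption, a standalone sketch of the same threshold search, using standard lo/hi bisection bookkeeping instead of the `dm`-halving above:

import tensorflow as tf

def calc_prob(x, alpha):
    # Assumed form of _calculate_probability (see lead-in above).
    return tf.math.pow(tf.nn.relu(x), 1.0 / (alpha - 1.0))

alpha = 1.5
scores = tf.constant([[2.0, 1.0, 0.1, -1.0]])
inputs = scores * (alpha - 1.0)
d = 4  # vocab size along the reduced axis
max_val = tf.reduce_max(inputs, axis=-1, keepdims=True)
tau_lo = max_val - 1.0                         # sum(p) >= 1 here
tau_hi = max_val - (1.0 / d) ** (alpha - 1.0)  # sum(p) <= 1 here
for _ in range(30):
    tau_m = (tau_lo + tau_hi) / 2.0
    p_m = calc_prob(inputs - tau_m, alpha)
    too_big = tf.reduce_sum(p_m, -1, keepdims=True) > 1.0
    tau_lo = tf.where(too_big, tau_m, tau_lo)
    tau_hi = tf.where(too_big, tau_hi, tau_m)
print(p_m.numpy())  # sparse probabilities that sum to ~1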
Example #8
 def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                               unused_step_ids, states,
                               unused_num_hyps_per_beam):
     # Same probs for each id.
     logits = tf.zeros([tgt_batch_size, vocab_size])
     # Except eoc has slightly lower score.
     logits = logits - 1.0 * tf.expand_dims(
         tf.one_hot(p.target_eoc_id, vocab_size), 0)
     # eos has a very low score (cannot terminate via eos).
     logits = logits + eos_score * tf.expand_dims(
         tf.one_hot(p.target_eos_id, vocab_size), 0)
     return py_utils.NestedMap(
         atten_probs=tf.zeros([tgt_batch_size, 0]),
         log_probs=logits,
         is_last_chunk=tf.fill([tgt_batch_size],
                               value=is_last_chunk)), states
Example #9
def FillPaddingPos(ids: tf.Tensor, id_len: tf.Tensor,
                   padding_value: int) -> tf.Tensor:
    """Given a batch of sequences, fills the padding pos with `padding_value`.

  Args:
    ids: a [B, max_len] int tensor.
    id_len: a [B, ] int tensor.
    padding_value: an int.

  Returns:
    new_ids: new ids with the following properties:
      - new_ids[b, :id_len[b]] = ids[b, :id_len[b]]
      - new_ids[b, id_len[b]:] = padding_value
  """
    mask = py_utils.SequencePaddings(id_len, maxlen=tf.shape(ids)[1])
    mask = tf.cast(mask, dtype=tf.bool)
    new_ids = tf.where(mask, tf.fill(tf.shape(ids), padding_value), ids)
    return new_ids
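A self-contained equivalent with made-up inputs, using `tf.sequence_mask` in place of `py_utils.SequencePaddings` (which is 1 at padded positions, hence the negation):

import tensorflow as tf

ids = tf.constant([[7, 8, 9], [4, 5, 6]])
id_len = tf.constant([2, 3])
padding_value = -1

# Padded positions are where the sequence mask is False.
mask = tf.logical_not(tf.sequence_mask(id_len, tf.shape(ids)[1]))
new_ids = tf.where(mask, tf.fill(tf.shape(ids), padding_value), ids)
print(new_ids.numpy())  # [[7 8 -1], [4 5 6]]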
Example #10
    def _InitIterator(self):
        """Override of the root's _InitIterator to support dataset repeat."""
        if self.host_id in self._dataset:
            return

        p = self.params
        self._repeat_steps = getattr(self._input_generator.params,
                                     'repeat_steps', None)
        self._repeat_with_sentinel = getattr(self._input_generator.params,
                                             'repeat_with_sentinel', None)

        with py_utils.GlobalStepContext(None):
            # Hide global_step tensor from being captured by dataset function.
            ds = self.GetDataset()
        if self._repeat_steps:
            tf.logging.info('Repeating dataset every %d steps.',
                            self._repeat_steps)
            ds = ds.take(self._repeat_steps).repeat()
        elif self._repeat_with_sentinel:
            tf.logging.info('Attaching sentinel to end of dataset and repeat.')
            # Dataset should contain batches of type NestedMap.
            sentinel_batch = ds.element_spec.Transform(
                lambda x: tf.zeros(x.shape, dtype=x.dtype))
            # Fill the dummy sentinel batch's sentinel_key tensor with sentinel_value.
            sentinel_batch[p.sentinel_key] = tf.fill(
                sentinel_batch[p.sentinel_key].shape, p.sentinel_value)
            tf.logging.info('attaching sentinel %r',
                            sentinel_batch[p.sentinel_key])
            tf.logging.info('sentinel type %r',
                            sentinel_batch[p.sentinel_key].dtype)
            ds = ds.concatenate(
                tf.data.Dataset.from_tensors(sentinel_batch)).repeat()
        options = tf.data.Options()
        options.experimental_deterministic = bool(self.cluster.in_unit_test)
        ds = ds.with_options(options)
        self._dataset[self.host_id] = ds
        if tf.executing_eagerly_outside_functions():
            it = iter(ds)
        else:
            it = tf.data.make_initializable_iterator(ds)
        self._iterator[self.host_id] = it
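A toy sketch of just the sentinel-and-repeat idea, with `tf.nest.map_structure` standing in for `NestedMap.Transform` and made-up stand-ins for `p.sentinel_key`/`p.sentinel_value`:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({'x': tf.ones([4, 2], tf.int32)})
sentinel_key, sentinel_value = 'x', -1  # stand-ins for the params

# An all-zeros batch of the right structure, with the sentinel tensor
# filled so downstream code can detect the end of an epoch.
sentinel_batch = tf.nest.map_structure(
    lambda spec: tf.zeros(spec.shape, spec.dtype), ds.element_spec)
sentinel_batch[sentinel_key] = tf.fill(
    sentinel_batch[sentinel_key].shape, sentinel_value)
ds = ds.concatenate(tf.data.Dataset.from_tensors(sentinel_batch)).repeat()
for i, batch in enumerate(ds.take(6)):
    print(i, batch['x'].numpy())  # every 5th batch is the [-1 -1] sentinel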
Example #11
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ],
                                               tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
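The `gather_indices` construction is the core trick here: pairing row ids with top-k column ids to select per-beam winners. A minimal sketch with made-up scores:

import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.5, 0.7],
                      [0.8, 0.2, 0.6, 0.4]])  # [source_batch, total_hyps]
max_hyps_per_beam = 2
source_batch = tf.shape(scores)[0]

_, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
batch_ids = tf.tile(
    tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)
top_scores = tf.gather_nd(scores, gather_indices)
print(top_scores.numpy())  # [[0.9 0.7], [0.8 0.6]]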
Example #12
  def Sample(self, decoder_theta, encoder_outputs, random_seed,
             init_state_callback, pre_step_callback, post_step_callback):
    """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.

    Returns:
      A NestedMap containing the following tensors:
      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but possibly ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
    p = self.params
    assert p.temperature > 0
    # 'recurrent_theta' represents all cross-timestep information used by the
    # recurrent loop below, including layer theta and encoder outputs.
    recurrent_theta = py_utils.NestedMap(
        theta=decoder_theta,
        random_seed=random_seed,
        encoder_outputs=encoder_outputs)
    bs_result, bs_state = init_state_callback(
        recurrent_theta.theta, encoder_outputs, num_hyps_per_beam=1)
    batch = tf.shape(bs_result.log_probs)[0]
    recurrent_state0 = py_utils.NestedMap(
        timestep=tf.zeros(shape=[], dtype=tf.int32),
        logits=bs_result.log_probs,
        # Start with target_sos_id.
        ids=tf.fill([batch], tf.to_int32(p.target_sos_id)),
        bs_state=bs_state)
    inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

    def Step(recurrent_theta, state0, inputs):
      """Computes one decoder step."""
      del inputs
      with tf.name_scope('single_sampler_step'):
        # Compute logits and states.
        bs_result, bs_state1 = pre_step_callback(
            recurrent_theta.theta,
            recurrent_theta.encoder_outputs,
            tf.expand_dims(state0.ids, 1),  # [batch, 1].
            state0.bs_state,
            num_hyps_per_beam=1)
        batch = tf.shape(bs_result.log_probs)[0]
        state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
        state1.logits = bs_result.log_probs
        # Sample ids from logits. [batch].
        state1.ids = tf.reshape(
            tf.random.stateless_multinomial(
                state1.logits / p.temperature,
                num_samples=1,
                seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
                output_dtype=state0.ids.dtype,
                name='sample_next_id'), [batch])
        if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
          state1.ids = tf.where(
              tf.logical_and(bs_result.is_last_chunk,
                             tf.equal(state1.ids, p.target_eoc_id)),
              tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
        state1.bs_state = post_step_callback(recurrent_theta.theta,
                                             recurrent_theta.encoder_outputs,
                                             state1.ids, bs_state1)
      return state1, py_utils.NestedMap()

    accumulated_states, _ = recurrent.Recurrent(recurrent_theta,
                                                recurrent_state0, inputs, Step)
    result = py_utils.NestedMap(
        logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
        ids=tf.transpose(accumulated_states.ids))
    result.paddings = tf.cast(
        _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
    # Force ids to be eos_id if the timestep is padded.
    result.ids = tf.where(
        tf.equal(result.paddings, 0), result.ids,
        tf.fill(tf.shape(result.ids), p.target_eos_id))
    static_batch_size = bs_result.log_probs.shape[0]
    result.ids.set_shape([static_batch_size, p.target_seq_len])
    result.paddings.set_shape([static_batch_size, p.target_seq_len])
    return result
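`_ComputePaddings` is not shown in this example. A plausible implementation, assumed here rather than taken from the library, marks every position strictly after the first eos as padding:

import tensorflow as tf

def compute_paddings_sketch(ids, eos_id):
    # 1 for every eos occurrence, 0 elsewhere.
    is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
    # Exclusive cumsum: how many eos tokens appear before each position.
    eos_seen = tf.cumsum(is_eos, axis=1, exclusive=True)
    return tf.cast(eos_seen > 0, tf.int32)

ids = tf.constant([[5, 2, 7, 7],   # eos (id 2) at t=1
                   [4, 4, 4, 4]])  # never terminates
print(compute_paddings_sketch(ids, 2).numpy())  # [[0 0 1 1], [0 0 0 0]]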
Example #13
  def BeamSearchDecode(self,
                       theta,
                       encoder_outputs,
                       num_hyps_per_beam_override=0,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
    """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
    p = self.params
    num_hyps_per_beam = p.num_hyps_per_beam
    if num_hyps_per_beam_override > 0:
      num_hyps_per_beam = num_hyps_per_beam_override
    if max_steps is None:
      max_steps = p.target_seq_len

    initial_results, other_states = init_beam_search_state(
        theta, encoder_outputs, num_hyps_per_beam)

    num_hyps = tf.shape(initial_results.log_probs)[0]
    num_beams = num_hyps // num_hyps_per_beam

    if 'step_ids' in initial_results:
      # [num_hyps, 1]
      step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
    else:
      step_ids = tf.fill([num_hyps, 1],
                         tf.constant(p.target_sos_id, dtype=tf.int32))

    min_score = -1e36
    best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
    cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
    in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
    in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
    bs_atten_probs = tf.zeros(
        [max_steps, num_hyps,
         tf.shape(initial_results.atten_probs)[1]],
        dtype=p.dtype)
    cur_step = tf.constant(0, dtype=tf.int32)
    all_done = tf.constant(False, dtype=tf.bool)
    core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                      in_prev_hyps, in_done_hyps, bs_atten_probs)

    def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                     unused_other_states_list):
      return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

    def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                 other_states_list):
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta, encoder_outputs, cur_step, step_ids, core_bs_states,
           other_states.Pack(other_states_list), num_hyps_per_beam,
           pre_beam_search_step_callback, post_beam_search_step_callback)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    flat_other_states = other_states.Flatten()
    _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
        LoopContinue,
        LoopBody,
        loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                   flat_other_states),
        parallel_iterations=10,
        back_prop=False,
        swap_memory=False,
        shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                          tf.TensorShape(all_done.get_shape()),
                          tf.TensorShape(step_ids.get_shape()),
                          _GetShapes(core_bs_states),
                          _GetShapes(flat_other_states, none_shapes=True)))
    # [target_seq_len, num_beams * num_hyps_per_beam].
    final_done_hyps = final_bs_states[5]
    final_other_states = other_states.Pack(flat_final_other_states)

    # TODO(rpang): avoid inspecting 'encoder_outputs'.
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      source_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
    else:
      source_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

    # [num_beams, num_hyps_per_beam].
    topk_hyps = ops.top_k_terminated_hyps(
        final_done_hyps,
        source_seq_lengths,
        k=num_hyps_per_beam,
        num_hyps_per_beam=num_hyps_per_beam,
        length_normalization=p.length_normalization,
        coverage_penalty=p.coverage_penalty,
        target_seq_length_ratio=p.target_seq_length_ratio,
        eoc_id=p.target_eoc_id,
        merge_paths=p.merge_paths)
    # [num_beams * num_hyps_per_beam, ...].
    max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
    topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
        tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
    # [num_beams, num_hyps_per_beam].
    topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

    return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                  topk_lens, topk_scores, None,
                                  final_other_states)
Example #14
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that `c_dim` <= `time_dim`; they need not be
          equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We
            # can accomplish this by adding +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len; we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all observed tokens to <eos>; note some of these need a zero
        # weight (or else there would be an <eos> and a valid token in the same
        # slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight; however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
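The subset sampling above relies on the Gumbel-max trick: adding Gumbel(0, 1) noise to (masked) logits and taking top-k yields a random draw without replacement. A toy sketch of just that step, with made-up sizes:

import tensorflow as tf

batch_size, time_dim = 1, 6
x_len = tf.constant([4])  # only positions 0..3 are valid

# Invalid (padded) positions get -1e9 so they always sort last.
z_logits = tf.cast(
    tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
    tf.float32) * -1e9
# Gumbel(0, 1) noise via inverse transform sampling.
z = -tf.math.log(-tf.math.log(tf.random.uniform([batch_size, time_dim])))
_, c_indices = tf.nn.top_k(z_logits + z, time_dim)
# The first 4 columns of c_indices are a uniformly random permutation of
# the valid positions 0..3.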
Example #15
    def _AddNoise(self, batch):
        """Adding noise the src (see https://arxiv.org/pdf/1711.00043).

    This function implement 3 types of noise (hyparams defined in
    self.params.denoise):
    1) slightly shuffle the sentence following p.shuffle_tok_range
    2) randomly drop tokens with probability p.drop_tok_prob
    3) randomly mask tokens with probability p.blank_tok_prob
    The noises are added to the input with probability p.noise_sent_prob.

    Args:
      batch: a `.NestedMap` of the input batch.
    """
        def IsSpecialExample(task_ids, special_task_ids):
            """A utility function indicates whether inputs belong to specific tasks.

      Args:
        task_ids: Task ids for the input batch. Tensor of shape [batch].
        special_task_ids: A list of specified task ids.

      Returns:
        A tensor of size [batch] indicating whether each sample in the batch
        belongs to one of the specified tasks.
      """
            batch_size = py_utils.GetShape(task_ids)[0]
            return tf.reduce_any(
                tf.equal(
                    tf.expand_dims(task_ids, -1),
                    tf.cast(
                        tf.broadcast_to(
                            special_task_ids,
                            [batch_size, len(special_task_ids)]), tf.int32)),
                -1)

        p = self.params.denoise
        batch_size = tf.shape(batch.src.ids)[0]
        source_max_len = tf.shape(batch.src.ids)[1]

        # Shuffle tokens according to p.shuffle_tok_range
        noise = tf.random.uniform([batch_size, source_max_len], 0,
                                  p.shuffle_tok_range + 1)

        # Don't shuffle eos or padding
        shuffle_tok_range = tf.fill([batch_size, source_max_len],
                                    float(p.shuffle_tok_range))
        shifted_paddings = tf.pad(batch.src.paddings[:, 1:], [[0, 0], [0, 1]],
                                  constant_values=1)
        noise = tf.where(tf.equal(shifted_paddings, 0), noise,
                         shuffle_tok_range)
        indices = tf.broadcast_to(tf.range(source_max_len, dtype=tf.int32),
                                  [batch_size, source_max_len])
        noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
        permutations = tf.argsort(noisy_indices)
        stacked = tf.stack([batch.src.ids, permutations], axis=1)
        denoise_src_ids = tf.stack(tf.map_fn(lambda x: tf.gather(x[0], x[1]),
                                             stacked),
                                   axis=0)

        # Select tokens to drop with probability=p.drop_tok_prob
        random_drop_tok = tf.random.uniform([batch_size, source_max_len])
        # Don't drop eos token
        is_keep_tok = tf.math.logical_or(
            tf.greater(random_drop_tok, p.drop_tok_prob),
            tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
        denoise_src_ids = tf.ragged.boolean_mask(
            denoise_src_ids,
            is_keep_tok).to_tensor(default_value=0,
                                   shape=tf.shape(batch.src.ids))
        denoise_src_paddings = tf.ragged.boolean_mask(
            batch.src.paddings,
            is_keep_tok).to_tensor(default_value=1,
                                   shape=tf.shape(batch.src.ids))

        # Select tokens to blank with probability=p.blank_tok_prob
        # Don't blank eos token
        random_blank_tok = tf.random.uniform([batch_size, source_max_len])
        shifted_paddings = tf.pad(denoise_src_paddings[:, 1:],
                                  [[0, 0], [0, 1]],
                                  constant_values=1)
        is_blank_tok = tf.math.logical_and(
            tf.less(random_blank_tok, p.blank_tok_prob),
            tf.equal(shifted_paddings, 0))
        blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
        denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

        # Select denoising task examples with probability=p.denoise_sent_prob
        random_uniform_sent = tf.random.uniform([batch_size])
        is_denoise_sent = tf.math.logical_and(
            tf.less(random_uniform_sent, p.noise_sent_prob),
            IsSpecialExample(self._GetTaskIds(batch.src.source_ids[:, 0]),
                             p.task_ids))
        batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids,
                                 batch.src.ids)
        batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                      batch.src.paddings)
        batch.src.ids_indicator = 1 - batch.src.paddings
        batch.src.weights = batch.src.ids_indicator
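A sketch of just the shuffle noise on a made-up sentence; `tf.gather` with `batch_dims=1` is a simpler TF2 equivalent of the `tf.map_fn` over stacked rows used above:

import tensorflow as tf

ids = tf.constant([[11, 12, 13, 14, 15]])
k = 3.0  # stand-in for p.shuffle_tok_range

# Each position index gets noise in [0, k + 1); after re-sorting, a token
# can only move a bounded number of places, giving a "slight" shuffle.
noise = tf.random.uniform(tf.shape(ids), 0.0, k + 1.0)
positions = tf.cast(
    tf.broadcast_to(tf.range(tf.shape(ids)[1]), tf.shape(ids)), tf.float32)
permutations = tf.argsort(positions + noise)
shuffled = tf.gather(ids, permutations, batch_dims=1)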
Example #16
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is the same
      as src_batch_size.

        - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
          token is encountered during search.
        - hyp_lens: [num_hyps].
        - done_hyps: [num_hyps], whether or not an eos is encountered.
    """
        p = self.params
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta,
            encoder_outputs,
            1  # num_hyps_per_beam
        )

        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        cur_step = tf.constant(0, dtype=tf.int32)
        done_hyps = inplace_ops.empty(shape=[num_hyps],
                                      dtype=tf.bool,
                                      init=True,
                                      name='done_hyps')
        hyp_lens = inplace_ops.empty(shape=[num_hyps],
                                     dtype=tf.int32,
                                     init=True,
                                     name='hyp_lens')
        hyp_ids = inplace_ops.empty(shape=[max_steps, num_hyps],
                                    dtype=tf.int32,
                                    init=True,
                                    name='hyp_ids')

        def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                         unused_hyp_lens, done_hyps, unused_other_states_list):
            return tf.logical_and(cur_step < max_steps,
                                  tf.logical_not(tf.reduce_all(done_hyps)))

        def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                     other_states_list):
            (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
             new_other_states) = self._GreedySearchStep(
                 theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
                 done_hyps, other_states.Pack(other_states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              tf.TensorShape(hyp_ids.get_shape()),
                              tf.TensorShape(hyp_lens.get_shape()),
                              tf.TensorShape(done_hyps.get_shape()),
                              _GetShapes(flat_other_states, none_shapes=True)))

        # transpose hyp_ids so it matches BeamSearchDecode's output
        final_hyp_ids = tf.transpose(final_hyp_ids)
        return final_hyp_ids, final_hyp_lens, final_done_hyps
Example #17
                def PerturbedLogProbs():
                    # STEP 1: Perform top-k filtering. This is done as a
                    # performance optimization to avoid sorting the entire
                    # `log_probs`, which would be prohibitively slow.
                    top_k = tf.math.top_k(log_probs, k, sorted=True)
                    # shape: [tgt_batch, k]
                    top_k_log_probs = top_k.values
                    # shape: [tgt_batch, k]
                    top_k_ids = top_k.indices

                    # STEP 2: Perform top-p filtering.
                    # shape: [tgt_batch]
                    top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
                    top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
                    top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
                    # shape: [tgt_batch, k]
                    filtered_top_k_log_probs = _KeepTopP(
                        top_k_log_probs, top_p_threshold)

                    # STEP 3: Perturb cumulative log-probs.
                    # shape: [tgt_batch, 1]
                    last_cumulative_log_probs = states.cumulative_log_probs
                    # shape: [tgt_batch, 1]
                    last_perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs
                    # Compute cumulative log-probs of the current step.
                    # shape: [tgt_batch, k]
                    cumulative_log_probs = (last_cumulative_log_probs +
                                            filtered_top_k_log_probs)
                    # Perturb cumulative log-probs by Gumbel noises under the condition
                    # that the max of the new perturbed log-probs is equal to
                    # perturbed_cumulative_log_probs of the previous step.
                    # shape: [tgt_batch, k]
                    new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
                        cumulative_log_probs,
                        last_perturbed_cumulative_log_probs,
                        encoder_outputs.stochastic_beam_search.seed, time_step,
                        encoder_outputs.stochastic_beam_search.src_ids,
                        encoder_outputs.stochastic_beam_search.src_paddings)

                    # STEP 4: Compute updated log_probs. This step is necessary because
                    # the output of PreBeamSearchStepCallback must be "per-step"
                    # log-probs, whereas so far "cumulative" log-probs have been computed.
                    # shape: [tgt_batch, k]
                    updated_top_k_log_probs = (
                        new_perturbed_cumulative_log_probs -
                        last_perturbed_cumulative_log_probs)
                    # Convert to the shape [tgt_batch, vocab_size].
                    updated_log_probs = tf.fill(
                        tf.shape(log_probs),
                        tf.constant(LARGE_NEGATIVE_NUMBER,
                                    dtype=log_probs.dtype))
                    updated_log_probs = _BatchScatter(updated_log_probs,
                                                      top_k_ids,
                                                      updated_top_k_log_probs)

                    return (updated_log_probs,
                            py_utils.NestedMap(
                                new_perturbed_cumulative_log_probs=
                                new_perturbed_cumulative_log_probs,
                                top_k_log_probs=top_k_log_probs,
                                top_k_ids=top_k_ids,
                            ))
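`_SampleGumbelWithMax` is not shown; the operation it names comes from stochastic beam search (Kool et al., 2019): perturb the cumulative log-probs with Gumbel noise, then shift each row so its maximum equals the previous step's perturbed score. A sketch under that assumption, dropping the seed/ids arguments of the real helper:

import tensorflow as tf

def sample_gumbel_with_max_sketch(phi, target_max):
    # Gumbel-perturb phi, then shift so max(result) == target_max exactly.
    gumbel = phi - tf.math.log(-tf.math.log(tf.random.uniform(tf.shape(phi))))
    row_max = tf.reduce_max(gumbel, axis=-1, keepdims=True)
    return -tf.math.log(
        tf.math.exp(-target_max) - tf.math.exp(-row_max) + tf.math.exp(-gumbel))

phi = tf.math.log([[0.5, 0.3, 0.2]])
out = sample_gumbel_with_max_sketch(phi, tf.constant([[0.0]]))
print(tf.reduce_max(out, -1).numpy())  # ~0.0, the conditioned maximum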
Example #18
 def FProp(self, _, inputs):
     return tf.fill(tf.shape(inputs), 42)
Example #19
        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    p.num_hyps_per_beam,
                    0)  # cur_step
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs
                sample_logits = state1.logits
                # Perform Nucleus Sampling. Assumes logits are in (-1e10, 1e3).
                if p.nucleus_p < 1.0:
                    max_logit = 1e3
                    min_logit = -1e10
                    sorted_logits = tf.sort(sample_logits,
                                            direction='DESCENDING',
                                            axis=-1)
                    sorted_probs = tf.nn.softmax(sorted_logits)
                    cumsum_probs = tf.math.cumsum(sorted_probs,
                                                  axis=-1,
                                                  exclusive=True)
                    masked_logits = tf.where(
                        cumsum_probs < p.nucleus_p, sorted_logits,
                        tf.ones_like(sorted_logits) * max_logit)
                    threshold = tf.math.reduce_min(masked_logits,
                                                   axis=-1,
                                                   keepdims=True)
                    sample_logits = tf.where(
                        sample_logits < threshold,
                        tf.ones_like(sorted_logits) * min_logit, sample_logits)
                # Note that here, we retain the possibility of applying both top_k
                # and nucleus filtering.
                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(sample_logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)
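The nucleus filter above, isolated on a single made-up row, to show which logits survive:

import tensorflow as tf

logits = tf.math.log([[0.5, 0.25, 0.15, 0.1]])
nucleus_p, max_logit, min_logit = 0.7, 1e3, -1e10

sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
cumsum_probs = tf.math.cumsum(
    tf.nn.softmax(sorted_logits), axis=-1, exclusive=True)
# Entries whose preceding mass already reaches p are pushed to max_logit
# so they cannot become the threshold.
masked = tf.where(cumsum_probs < nucleus_p, sorted_logits,
                  tf.ones_like(sorted_logits) * max_logit)
threshold = tf.reduce_min(masked, axis=-1, keepdims=True)
filtered = tf.where(logits < threshold,
                    tf.ones_like(logits) * min_logit, logits)
# Only the top two tokens (cumulative mass 0.75 >= 0.7) remain sample-able.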
Example #20
    def Sample(self,
               decoder_theta,
               encoder_outputs,
               random_seed,
               init_state_callback,
               pre_step_callback,
               post_step_callback,
               init_step_ids=None):
        """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.
      init_step_ids: [batch], optional init step ids, default to SOS.

    Returns:
      A NestedMap containing the following tensors

      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but possibly ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
        p = self.params
        assert p.temperature > 0
        assert p.top_k >= 0
        assert p.num_hyps_per_beam >= 1
        if getattr(encoder_outputs, 'segment_id', 1) is None:
            # Remove None values, which are not supported by recurrent.
            del encoder_outputs['segment_id']
        # init_state_callback may modify 'encoder_outputs', e.g., by inserting
        # 'packed_src'.
        bs_result, bs_state = init_state_callback(decoder_theta,
                                                  encoder_outputs,
                                                  p.num_hyps_per_beam)
        # 'recurrent_theta' represents all cross-timestep information used by the
        # recurrent loop below, including layer theta and encoder outputs.
        recurrent_theta = py_utils.NestedMap(random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=init_step_ids if init_step_ids is not None else tf.fill(
                [batch], tf.cast(p.target_sos_id, tf.int32)),
            bs_state=bs_state)

        if p.use_recurrent:
            inputs = py_utils.NestedMap(
                dummy=tf.zeros([p.target_seq_len, batch]))
        else:
            inputs = py_utils.NestedMap(
                ids=tf.TensorArray(dtype=tf.int32, size=p.target_seq_len),
                logits=tf.TensorArray(dtype=bs_result.log_probs.dtype,
                                      size=p.target_seq_len),
            )

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=p.num_hyps_per_beam)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)

        if p.use_recurrent:

            def StopFn(t, theta, state):
                del t, theta  # Unused: this stop function only uses the state ids.
                return tf.equal(state.ids, p.target_eos_id)
        else:

            def StopFn(recurrent_theta, state, inputs):
                del recurrent_theta, inputs
                return tf.logical_not(
                    tf.reduce_all(tf.equal(state.ids, p.target_eos_id)))

        if p.use_stop_fn:
            stop_fn = StopFn
        else:
            stop_fn = None

        if p.use_recurrent:
            accumulated_states, _ = recurrent.Recurrent(
                recurrent_theta,
                recurrent_state0,
                inputs,
                Step,
                stop_fn=stop_fn,
                allow_implicit_capture=True)
        else:
            loop_vars = (recurrent_theta, recurrent_state0, inputs)
            (_, _, accumulated_states) = tf.while_loop(
                StopFn,
                Step,
                loop_vars=loop_vars,
                shape_invariants=_GetShapes(loop_vars, none_shapes=True),
                back_prop=False,
                maximum_iterations=p.target_seq_len)
            accumulated_states.ids = accumulated_states.ids.stack()
            accumulated_states.logits = accumulated_states.logits.stack()

        result = py_utils.NestedMap(logits=tf.transpose(
            accumulated_states.logits, [1, 0, 2]),
                                    ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
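Finally, a stripped-down sketch of the non-recurrent accumulation pattern used when `p.use_recurrent` is False: a `tf.while_loop` whose body writes each step into a `tf.TensorArray`, stacked and transposed afterwards:

import tensorflow as tf

batch, steps = 2, 4
ta = tf.TensorArray(tf.int32, size=steps)
state = tf.zeros([batch], tf.int32)

def body(t, state, ta):
    state = state + 1  # stand-in for one sampling step
    return t + 1, state, ta.write(t, state)

_, _, ta = tf.while_loop(lambda t, *_: t < steps, body,
                         (tf.constant(0), state, ta))
ids = tf.transpose(ta.stack())  # stack: [steps, batch]; transpose: [batch, steps]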