Example #1
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
        p = self.params
        num_hyps_per_beam = p.num_hyps_per_beam
        if num_hyps_per_beam_override > 0:
            num_hyps_per_beam = num_hyps_per_beam_override
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        min_score = -1e36
        best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
        in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
        bs_atten_probs = tf.zeros(
            [max_steps, num_hyps,
             tf.shape(initial_results.atten_probs)[1]],
            dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def LoopContinue(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
            return tf.logical_and(cur_step < max_steps,
                                  tf.logical_not(all_done))

        def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                 other_states.Pack(other_states_list), num_hyps_per_beam,
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(all_done.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              _GetShapes(core_bs_states),
                              _GetShapes(flat_other_states, none_shapes=True)))
        # [target_seq_len, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # TODO(rpang): avoid inspecting 'encoder_outputs'.
        source_paddings = encoder_outputs.padding
        if isinstance(source_paddings, py_utils.NestedMap):
            source_seq_lengths = tf.cast(
                tf.round(
                    tf.reduce_sum(
                        1.0 - tf.transpose(source_paddings.Flatten()[0]), 1)),
                tf.int32)
        else:
            source_seq_lengths = tf.cast(
                tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings),
                                       1)), tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            source_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...].
        max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
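The decode loop above hinges on tf.while_loop carrying a tuple of loop variables whose shapes are pinned by shape_invariants. Below is a minimal, self-contained sketch of that pattern (toy names and values, not Lingvo code): a tf.TensorShape([None]) invariant lets one loop variable change size across iterations, while the scalar step counter and done flag keep their static shapes.

import tensorflow as tf

def _Cond(step, all_done, acc):
  return tf.logical_and(step < 5, tf.logical_not(all_done))

def _Body(step, all_done, acc):
  # The accumulator grows by one element per iteration, which is only legal
  # because its shape invariant below is TensorShape([None]).
  acc = tf.concat([acc, tf.fill([1], step)], axis=0)
  return step + 1, step >= 3, acc

step0 = tf.constant(0, dtype=tf.int32)
done0 = tf.constant(False)
acc0 = tf.zeros([0], dtype=tf.int32)
_, _, acc = tf.while_loop(
    _Cond, _Body,
    loop_vars=(step0, done0, acc0),
    shape_invariants=(step0.get_shape(), done0.get_shape(),
                      tf.TensorShape([None])))
# acc == [0, 1, 2, 3]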
Example #2
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
      src_batch_size.

        - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
          token is encountered during search.
        - hyp_lens: [num_hyps].
        - done_hyps: [num_hyps], whether or not an eos is encountered.
    """
        p = self.params
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta,
            encoder_outputs,
            1  # num_hyps_per_beam
        )

        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        cur_step = tf.constant(0, dtype=tf.int32)
        done_hyps = inplace_ops.empty(shape=[num_hyps],
                                      dtype=tf.bool,
                                      init=True,
                                      name='done_hyps')
        hyp_lens = inplace_ops.empty(shape=[num_hyps],
                                     dtype=tf.int32,
                                     init=True,
                                     name='hyp_lens')
        hyp_ids = inplace_ops.empty(shape=[max_steps, num_hyps],
                                    dtype=tf.int32,
                                    init=True,
                                    name='hyp_ids')

        def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                         unused_hyp_lens, done_hyps, unused_other_states_list):
            return tf.logical_and(cur_step < max_steps,
                                  tf.logical_not(tf.reduce_all(done_hyps)))

        def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                     other_states_list):
            (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
             new_other_states) = self._GreedySearchStep(
                 theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
                 done_hyps, other_states.Pack(other_states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              tf.TensorShape(hyp_ids.get_shape()),
                              tf.TensorShape(hyp_lens.get_shape()),
                              tf.TensorShape(done_hyps.get_shape()),
                              _GetShapes(flat_other_states, none_shapes=True)))

        # transpose hyp_ids so it matches BeamSearchDecode's output
        final_hyp_ids = tf.transpose(final_hyp_ids)
        return final_hyp_ids, final_hyp_lens, final_done_hyps
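A toy version of the early-exit logic above (invented eos_id, batch size, and a random stand-in for the model): the loop stops as soon as every row has produced <eos>, via tf.reduce_all(done_hyps), or when max_steps is hit.

import tensorflow as tf

eos_id, max_steps, batch = 2, 10, 3

def _Cond(step, done, unused_ta):
  return tf.logical_and(step < max_steps, tf.logical_not(tf.reduce_all(done)))

def _Body(step, done, ids_ta):
  # Stand-in for one greedy decode step; a real model would emit argmax ids.
  new_ids = tf.random.uniform([batch], maxval=4, dtype=tf.int32)
  done = tf.logical_or(done, tf.equal(new_ids, eos_id))
  return step + 1, done, ids_ta.write(step, new_ids)

ids_ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
_, done_hyps, ids_ta = tf.while_loop(
    _Cond, _Body, (tf.constant(0), tf.zeros([batch], tf.bool), ids_ta))
hyp_ids = tf.transpose(ids_ta.stack())  # [batch, steps], as in the code above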
Example #3
    def _updated_statistics(self, var, partitioned_grads):
        """Returns updated Shampoo statistics L_t, R_t, etc.

    Args:
      var: tf.Variable associated with the gradient.
      partitioned_grads: Partitioned gradient tensor.

    Returns:
      A list of updated statistics matrices.
    """
        precond_statistics_update = []
        num_partitions = len(partitioned_grads)
        mat_stats = []
        mat_grads = []
        mat_dims = []
        for pt_idx, pt_grad in enumerate(partitioned_grads):
            pt_shape = pt_grad.get_shape()
            preconditioner_exists_for_dim = (
                self._preconditioner_available_for_dims(pt_shape))
            rank = len(pt_shape)
            # Calculates the preconditioner statistics for each tensor.
            for i in range(rank):
                if preconditioner_exists_for_dim[i]:
                    mat_stats.append(
                        self.get_slot(
                            var,
                            self._statistics_key_for_partition_and_dim(
                                i, pt_idx, num_partitions)))
                    mat_grads.append(pt_grad)
                    mat_dims.append(i)

        # axes is the list of indices to reduce - everything but
        # the current i.
        def _update_statistics(dim, stat_var, grad):
            """Update preconditioner statistics."""
            with tf.name_scope("GradientStatistics"):
                var_rank = len(grad.get_shape())
                axes = list(range(dim)) + list(range(dim + 1, var_rank))
                new_stat = math_ops.tensordot(grad, grad, axes=(axes, axes))
                if self._second_moment_averaging == 1.0:
                    updated_stat = state_ops.assign_add(stat_var, new_stat)
                else:
                    updated_stat = state_ops.assign_add(
                        stat_var,
                        (self._second_moment_averaging - 1.0) * stat_var +
                        (1.0 - self._second_moment_averaging) * new_stat)
                return updated_stat

        if self._statistics_computation_frequency <= 1:
            for mat_stat, mat_grad, dim in zip(mat_stats, mat_grads, mat_dims):
                precond_statistics_update.append(
                    _update_statistics(dim, mat_stat, mat_grad))
        else:

            # NOTE: We rewrite tf.cond() as a while loop to avoid certain overheads
            # in XLA from buffer allocation.
            def _loop_body(mat_stats, mat_grads, mat_dims,
                           unused_perform_step):
                precond_statistics_update_ops = []
                for mat_stat, mat_grad, dim in zip(mat_stats, mat_grads,
                                                   mat_dims):
                    precond_statistics_update_ops.append(
                        _update_statistics(dim, mat_stat, mat_grad))
                with tf.control_dependencies(precond_statistics_update_ops):
                    return tf.constant(False)

            loop_body_fn = functools.partial(_loop_body, mat_stats, mat_grads,
                                             mat_dims)
            precond_statistics_update.append(
                tf.while_loop(lambda perform_step: perform_step, loop_body_fn,
                              [self._run_statistics_computation]))

        return precond_statistics_update
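The NOTE above replaces tf.cond with a degenerate tf.while_loop whose body runs at most once. A minimal sketch of the trick (toy variable and predicate, not the Shampoo code): the boolean loop variable gates execution, and the body always returns False so the loop terminates after a single pass.

import tensorflow as tf

v = tf.Variable(0.0)
should_run = tf.constant(True)  # e.g. tf.equal(global_step % freq, 0)

def _Body(unused_perform_step):
  update_op = v.assign_add(1.0)  # stand-in for the statistics updates
  with tf.control_dependencies([update_op]):
    return tf.constant(False)    # always stop after the first iteration

tf.while_loop(lambda perform_step: perform_step, _Body, [should_run])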
Example #4
    def _OutfeedDequeueLoop(self, per_example_tensors, num_loops, num_devices):
        """Process all per-example tensor outfeed data for a TPU sess.run.

    Args:
      per_example_tensors: dict of key -> tensor as generated by TpuTrainStep.
      num_loops: number of times that TpuTrainStep will be executed by TpuTrain.
      num_devices: number of TPU cores assigned to this process.

    Returns:
      A dict of per-example tensors from the latest TpuTrainStep.
    """
        if not per_example_tensors:
            return tf.no_op()

        tensor_shapes = [
            py_utils.GetShape(per_example_tensors[key])
            for key in sorted(per_example_tensors)
        ]
        tensor_types = [
            tf.as_dtype(per_example_tensors[key].dtype)
            for key in sorted(per_example_tensors)
        ]

        def LoopBody(i, *input_arrays):
            """Process outfeed data for a single TpuTrainStep.

      Args:
        i: current loop index.
        *input_arrays: One tf.TensorArray per outfeed tensor.

      Returns:
        i+1 (new index) plus post-write tf.TensorArray handles.
      """
            # Outfeed ops execute on each JF node, so they must be located on the
            # nodes.
            outfeed_devices = []
            device_assignment = py_utils.GetTpuDeviceAssignment()
            assert device_assignment
            for replica in range(device_assignment.num_replicas):
                for core in range(device_assignment.num_cores_per_replica):
                    with tf.device(device_assignment.host_device(
                            replica, core)):
                        outfeed_devices.append(
                            tpu_ops.outfeed_dequeue_tuple(
                                tensor_types,
                                tensor_shapes,
                                device_ordinal=device_assignment.tpu_ordinal(
                                    replica, core)))
            offset = i * num_devices
            output_arrays = list(input_arrays)
            # Each output_array holds a different per-example tensor. We get results
            # for each tensor from each TPU for each TpuTrainStep call.
            for j in range(len(output_arrays)):
                for k in range(len(outfeed_devices)):
                    output_arrays[j] = output_arrays[j].write(
                        offset + k, outfeed_devices[k][j])

            return tuple([i + 1] + output_arrays)

        def LoopCond(i, *output_arrays):
            del output_arrays
            return i < num_loops

        output_arrays = []
        for i in range(len(tensor_shapes)):
            output_arrays.append(
                tf.TensorArray(tensor_types[i],
                               size=num_loops * num_devices,
                               element_shape=tensor_shapes[i]))
        # Loop once for each time that TpuTrainStep runs.
        output_arrays = tf.while_loop(LoopCond,
                                      LoopBody, [0] + output_arrays,
                                      parallel_iterations=1)[1:]
        concatenated_arrays = [array.concat() for array in output_arrays]
        return dict(zip(sorted(per_example_tensors), concatenated_arrays))
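A toy sketch of the flat indexing scheme used above (fabricated sizes; no TPU involved): each loop iteration i and device k writes to slot i * num_devices + k, and everything is read back at the end.

import tensorflow as tf

num_loops, num_devices = 3, 2
ta = tf.TensorArray(tf.int32, size=num_loops * num_devices, element_shape=[])

def _Body(i, ta):
  for k in range(num_devices):  # unrolled Python loop, like the device loop
    ta = ta.write(i * num_devices + k, i * 10 + k)  # stand-in outfeed value
  return i + 1, ta

_, ta = tf.while_loop(lambda i, _: i < num_loops, _Body,
                      [tf.constant(0), ta], parallel_iterations=1)
print(ta.stack())  # [0, 1, 10, 11, 20, 21]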
Example #5
    def ComputePredictions(self,
                           encoder_outputs,
                           pronunciations,
                           is_inference=False):
        """Computes the predictions from the encoder_outputs, updating losses.

    Despite the name, this function does the bulk of the decoding and loss
    computation, incrementing the loss at each time step.

    Args:
      encoder_outputs: a NestedMap consisting of outputs of the
        FeatureNeighborhoodEncoder with  encoded - encoding of the input
        spelling
        neighbor_pronunciations_encoded - encodings of the neighbor prons
        neighbor_pronunciations_encoded - encodings of the neighbor spellings
        state - encoder state to which has been added dec_input - seed output
        for the decoder [*, 1] tensor consisting of sentence start indices
        (corresponding to "<s>")
      pronunciations: NestedMap with pronunciations - [*, max_pronunciation_len]
        tensor of pronunciations
      is_inference: If False then uses teacher forcing else does autoregression.

    Returns:
      NestedMap with loss, per_sequence_losses,labels, a
      [*, max_pronunciation_len] tensor of predictions, and attention
      ([*, max_pronunciation_len, max_spelling_len]), and
      neighbor_attention ([*, max_pronunciation_len, max_neighbors])
      tensors, along with the raw batch passed through from the encoder.
    """
        p = self.params
        targets = pronunciations.pronunciations
        t_len = int(targets.get_shape().as_list()[1])
        t_idx = tf.constant(0)
        attention = tf.TensorArray(dtype=tf.float32, size=t_len)
        neighbor_attention = tf.TensorArray(dtype=tf.float32, size=t_len)

        outputs = tf.TensorArray(dtype=tf.float32, size=t_len)

        loop_cond = lambda t_idx, ts, *_: tf.less(t_idx, t_len)

        dec_input = tf.convert_to_tensor([p.start] * p.input.batch_size)
        state = encoder_outputs.state

        # pylint: disable=missing-docstring
        def loop_body(t_idx, dec_input, attention, neighbor_attention, state,
                      outputs):
            decoder_result = self.Decode(encoder_outputs, dec_input, state)

            outputs = outputs.write(t_idx, decoder_result.predictions)
            attention = attention.write(t_idx,
                                        decoder_result.attention_weights)
            neighbor_attention = neighbor_attention.write(
                t_idx,
                tf.cast(decoder_result.neighbor_attention_weights,
                        dtype=tf.float32))

            if is_inference:
                dec_input = tf.cast(tf.argmax(decoder_result.predictions, 1),
                                    tf.int32)
            else:
                dec_input = targets[:, t_idx]
            t_idx = t_idx + 1
            state = decoder_result.state
            return t_idx, dec_input, attention, neighbor_attention, state, outputs

        _, _, attention, neighbor_attention, state, outputs = tf.while_loop(
            loop_cond,
            loop_body,
            loop_vars=[
                t_idx, dec_input, attention, neighbor_attention, state, outputs
            ])

        outputs = tf.transpose(outputs.stack(), [1, 0, 2])
        labels = tf.argmax(outputs, axis=-1)
        mask = tf.cast(tf.math.logical_not(tf.math.equal(targets, 0)),
                       dtype=tf.float32)
        loss = self._loss_object(targets, outputs, sample_weight=mask)
        loss = tf.reduce_sum(loss, axis=1)
        per_sequence_losses = (loss / t_len)
        loss = tf.reduce_mean(per_sequence_losses)
        predictions = py_utils.NestedMap()
        predictions.loss = loss
        predictions.per_sequence_losses = per_sequence_losses
        predictions.labels = labels
        predictions.attention = tf.transpose(tf.squeeze(attention.stack()),
                                             perm=[1, 0, 2])
        if p.use_neighbors:
            predictions.neighbor_attention = tf.transpose(tf.squeeze(
                neighbor_attention.stack()),
                                                          perm=[1, 0, 2])
        else:
            predictions.neighbor_attention = tf.squeeze(
                neighbor_attention.stack())
        # Expose this for subsequent data analysis
        predictions.batch = encoder_outputs.batch
        return predictions
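The is_inference branch above is the standard teacher-forcing switch. A minimal illustration with invented shapes: during training the next decoder input is the ground-truth target at the current step; at inference it is the model's own argmax prediction.

import tensorflow as tf

logits = tf.constant([[0.1, 2.0, 0.3]])  # [batch=1, vocab=3] decoder output
targets = tf.constant([[2, 1]])          # [batch=1, max_len=2] ground truth
t_idx, is_inference = 0, False

if is_inference:
  # Autoregression: feed back the model's own best guess.
  dec_input = tf.cast(tf.argmax(logits, axis=1), tf.int32)
else:
  # Teacher forcing: feed the ground-truth symbol for this step.
  dec_input = targets[:, t_idx]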
Example #6
    def _EncodeToIds(self, word):
        # Below:
        #   * a token is a wordpiece ID.
        #   * the tokens array will be merged in-place.
        #   * the candidates array is an array of size len(tokens) - 1.
        #     It contains the token for the merged wordpiece if it exists,
        #     NO_TOKEN otherwise. For instance, candidates[3] = id(token[3] +
        #     token[4]).
        # First, split into basic UTF-8 characters (letters).
        chars = tf.strings.unicode_split(word, 'UTF-8')
        tokens = self._StringToToken(chars)
        tokens = tf.where(
            tf.equal(tokens, NO_TOKEN),
            # Unseen character.
            tf.broadcast_to(self.unk_id, tf.shape(tokens)),
            tokens)
        # Create initial candidate list.
        candidates = tf.map_fn(self._MergeTokens, (tokens[:-1], tokens[1:]),
                               dtype=tokens.dtype)

        def _ShouldMerge(unused_tokens, candidates):
            """Merge until not possible, or we abort early according to merge_prob."""
            return tf.logical_and(
                tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
                tf.random.uniform([]) < self._merge_prob)

        def _MergeOneToken(tokens, i):
            return tf.expand_dims(self._MergeTokens(
                (tokens[i], tokens[i + 1])),
                                  axis=-1)

        def _MergeCandidates(tokens, candidates):
            """Merge in the reverse binary tree."""
            best_id = tf.argmin(candidates, output_type=tf.int32)
            # Perform the merge at position best_id.
            tokens = tf.concat([
                tokens[:best_id], [candidates[best_id]], tokens[best_id + 2:]
            ],
                               axis=0)
            # Recompute the merge candidates.
            # Only the neighbors of best_id need to be recomputed.
            empty = tf.zeros([0], dtype=candidates.dtype)

            def _MergeLeft():
                return tf.concat([
                    candidates[:best_id - 1],
                    _MergeOneToken(tokens, best_id - 1)
                ],
                                 axis=0)

            left_candidates = tf.cond(tf.equal(best_id, 0), lambda: empty,
                                      _MergeLeft)

            def _MergeRight():
                return tf.concat([
                    _MergeOneToken(tokens, best_id), candidates[best_id + 2:]
                ],
                                 axis=0)

            right_candidates = tf.cond(
                tf.greater_equal(best_id,
                                 tf.size(tokens) - 1), lambda: empty,
                _MergeRight)

            candidates = tf.concat([left_candidates, right_candidates], axis=0)
            return tokens, candidates

        return tf.while_loop(_ShouldMerge,
                             _MergeCandidates, (tokens, candidates),
                             parallel_iterations=1,
                             back_prop=False)[0]
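A pure-Python analogue of one _MergeCandidates step (my sketch, not the Lingvo implementation; `merge` stands in for self._MergeTokens, and NO_TOKEN is assumed to be larger than any real wordpiece id, so the argmin always prefers a real merge):

NO_TOKEN = 1 << 30  # assumed sentinel, larger than any wordpiece id

def merge_step(tokens, candidates, merge):
  # Pick the highest-priority merge (smallest candidate id).
  best = min(range(len(candidates)), key=candidates.__getitem__)
  # Replace the pair with the merged token.
  tokens = tokens[:best] + [candidates[best]] + tokens[best + 2:]
  # Only the candidates adjacent to the merge point need recomputing.
  left = (candidates[:best - 1] + [merge(tokens[best - 1], tokens[best])]
          if best > 0 else [])
  right = ([merge(tokens[best], tokens[best + 1])] + candidates[best + 2:]
           if best < len(tokens) - 1 else [])
  return tokens, left + right

# 'ab' merges before 'bc' because it has the smaller merged id.
pairs = {(1, 2): 5, (2, 3): 6}
merge = lambda a, b: pairs.get((a, b), NO_TOKEN)
tokens = [1, 2, 3]
candidates = [merge(1, 2), merge(2, 3)]  # [5, 6]
tokens, candidates = merge_step(tokens, candidates, merge)
# tokens == [5, 3]; candidates == [NO_TOKEN] since (5, 3) has no merge.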
Example #7
  def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
    """Takes a tensor of strings and returns id/padding tensors.

    This generates `token_ids`, `target_ids`, and `paddings` in the format that
    is expected for tokenizers. This performs padding to a fixed length and
    appends the end-of-sentence token as appropriate.

    Args:
      strs: a string Tensor.
      max_length: a python integer. The second dimension of the returned arrays.
        All sequences are padded or truncated to that length.
      append_eos: a python bool. See `BaseTokenizer` for explanation.
      languages: A vector of strings with the same length as `strs`.

    Returns:
      token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
        always end with EOS unless the sequence exceeds the maximum length.
        Always padded with EOS.
      target_ids: a tensor of sequences of WPM ids not starting with SOS
        but ending with EOS. Always padded with EOS.
      paddings: a tensor of floats indicating, at each position, whether
        the corresponding position is padded.
    """
    p = self.params
    if append_eos is None:
      append_eos = p.append_eos

    batch_size = py_utils.GetShape(strs)[0]
    token_ids_ta = tf.TensorArray(tf.int32, batch_size)
    target_ids_ta = tf.TensorArray(tf.int32, batch_size)
    paddings_ta = tf.TensorArray(tf.float32, batch_size)

    def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
      """Tokenizes a single sentence."""
      ids, _ = self._wpm_encoder.Encode(strs[i])

      if append_eos:
        ids = tf.concat([ids, [self.eos_id]], axis=0)

      # This truncates after the eos is added, so some sentences might
      # not have </s> at the end.
      token_ids_ta = token_ids_ta.write(
          i,
          py_utils.PadOrTrimTo(
              tf.concat([[self.sos_id], ids], axis=0), [max_length],
              self.eos_id))
      target_ids_ta = target_ids_ta.write(
          i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
      paddings_ta = paddings_ta.write(
          i,
          py_utils.PadOrTrimTo(
              tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

      return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

    _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
        lambda i, *_: i < batch_size,
        _TokenizeOneSentence,
        loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                   paddings_ta),
        parallel_iterations=30,
        back_prop=False)

    token_ids = token_ids_ta.stack()
    target_ids = target_ids_ta.stack()
    paddings = paddings_ta.stack()

    if not p.pad_to_max_length:
      maxlen = tf.to_int32(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1)))
      token_ids = token_ids[:, :maxlen]
      target_ids = target_ids[:, :maxlen]
      paddings = paddings[:, :maxlen]

    return token_ids, target_ids, paddings
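py_utils.PadOrTrimTo pads with a given value or truncates to a fixed shape. A rough single-axis equivalent, as a sketch (my helper, not the Lingvo one):

import tensorflow as tf

def pad_or_trim_1d(x, length, pad_value):
  x = x[:length]  # trim if too long
  pad = tf.fill([length - tf.shape(x)[0]], pad_value)  # empty if long enough
  return tf.concat([x, tf.cast(pad, x.dtype)], axis=0)

print(pad_or_trim_1d(tf.constant([3, 7, 9]), 5, 0))  # [3 7 9 0 0]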
Example #8
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
  """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding to a selected set.
  2. For all remaining points, find the furthest point from those selected.
  3. Add furthest point to selected.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx).

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point is
      real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      squared distances between each pair of points. If None, distances will
      be computed on the fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. These values are in the range [0, P1).

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part of
    the pooling operation: each point is assigned to the closest sampled point
    and a max is taken over them. These values are in the range
    [0, num_sampled_points), indexing the closest sampled point that remains.
  """
  points = py_utils.HasRank(points, 3)
  batch_size, num_points, dims = py_utils.GetShape(points, 3)

  points = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(num_points, num_sampled_points)], points)

  # Add a tiny bit of noise to the distance matrix or the points so that all
  # points are unique. This also ensures that truly repeated points, such as
  # padded points, are only selected after all valid points have been selected.
  if precomputed_squared_distance is not None:
    precomputed_squared_distance = py_utils.HasShape(
        precomputed_squared_distance, [batch_size, num_points, num_points])
    precomputed_squared_distance += tf.random.uniform(
        (batch_size, num_points, 1),
        minval=1e-6,
        maxval=1e-5,
        dtype=tf.float32,
        seed=random_seed)
  else:
    points += tf.random.uniform((batch_size, num_points, dims),
                                minval=1e-6,
                                maxval=1e-5,
                                dtype=tf.float32,
                                seed=random_seed)

  # TensorArray to store the sampled indices in the loop.
  sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

  # Initialize distance_to_selected to inf for all points.
  distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

  # For tracking the index to the closest selected point.
  closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

  # Current loop index counter.
  curr_idx = tf.constant(0, dtype=tf.int32)

  # Get the number of valid points (padding == 1 marks a padded point, so this
  # is num_points - num_padded).
  num_valid_points = tf.cast(
      tf.cast(num_points, dtype=tf.float32) - tf.reduce_sum(padding, axis=1),
      dtype=tf.int32)

  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create a random value per point, and then set all padded ones to some
      large value (more than the maxval). We then take the min per batch
      element to get the first point.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

  _, _, sampled_idx, closest_idx = tf.while_loop(
      lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
      _BodyFn,
      loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
      back_prop=False,
      maximum_iterations=num_sampled_points)

  sampled_idx = sampled_idx.stack()  # [num_sampled_points, batch_size]
  sampled_idx = tf.transpose(sampled_idx, [1, 0])

  if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
    sampled_idx.set_shape((batch_size, num_sampled_points))

  return sampled_idx, closest_idx
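A usage sketch based on the docstring above (made-up shapes): gather the sampled points from the returned indices. tf.batch_gather is the TF1 name; tf.gather with batch_dims=1 is the equivalent shown here.

import tensorflow as tf

points = tf.random.uniform([4, 128, 3])  # [N, P1, dims]
padding = tf.zeros([4, 128])             # all points real
sampled_idx, closest_idx = FarthestPointSampler(points, padding, 16)
sampled_points = tf.gather(points, sampled_idx, batch_dims=1)  # [4, 16, 3]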
Example #9
    def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
        del languages
        p = self.params
        if append_eos is None:
            append_eos = p.append_eos

        batch_size = py_utils.GetShape(strs)[0]
        token_ids_ta = tf.TensorArray(tf.int32, batch_size)
        target_ids_ta = tf.TensorArray(tf.int32, batch_size)
        paddings_ta = tf.TensorArray(tf.float32, batch_size)

        def _TokenizeOneSentence(i, text, token_ids_ta, target_ids_ta,
                                 paddings_ta):
            """Tokenizes a single sentence."""
            if tf.is_tensor(i):
                text_i = tf.gather(text, i)
            else:
                text_i = text[i]
            ids = self._tokenizer.tokenize(text_i).merge_dims(0, -1)
            ids.set_shape([None])

            if append_eos:
                ids = tf.concat([ids, [self.eos_id]], axis=0)
            sos_ids = tf.concat([[self.sos_id], ids], axis=0)
            if p.prepend_sos:
                ids = sos_ids

            # This truncates after the EOS is added, so some sentences might
            # not have EOS at the end.
            token_ids_ta = token_ids_ta.write(
                i, py_utils.PadOrTrimTo(sos_ids, [max_length], 0))
            target_ids_ta = target_ids_ta.write(
                i, py_utils.PadOrTrimTo(ids, [max_length], 0))
            paddings_ta = paddings_ta.write(
                i,
                py_utils.PadOrTrimTo(tf.zeros_like(ids, dtype=tf.float32),
                                     [max_length], 1.))

            return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

        _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
            lambda i, *_: i < batch_size,
            _TokenizeOneSentence,
            loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta,
                       target_ids_ta, paddings_ta),
            parallel_iterations=30,
            back_prop=False)

        token_ids = token_ids_ta.stack()
        target_ids = target_ids_ta.stack()
        paddings = paddings_ta.stack()

        if not p.pad_to_max_length:
            maxlen = tf.cast(
                tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
                tf.int32)
            token_ids = token_ids[:, :maxlen]
            target_ids = target_ids[:, :maxlen]
            paddings = paddings[:, :maxlen]

        return token_ids, target_ids, paddings
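A toy illustration (fabricated paddings) of the dynamic trim at the end of both tokenizer examples: the longest real sequence length is recovered from the paddings, and every returned tensor is sliced to it.

import tensorflow as tf

paddings = tf.constant([[0., 0., 1., 1.],
                        [0., 0., 0., 1.]])  # row lengths 2 and 3
maxlen = tf.cast(
    tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))), tf.int32)
# maxlen == 3, so the trailing all-padding column is dropped.
trimmed = paddings[:, :maxlen]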