Example #1
    def get_mask_values(self, masked_lm_ids):
        """Get the values used for masking, random injection or no-op.

    Args:
      masked_lm_ids: a `RaggedTensor` of n dimensions and dtype int32 or int64
        whose values are the ids of items that have been selected for masking.
    Returns:
      a `RaggedTensor` of the same dtype and shape as `masked_lm_ids`, whose
      values contain either the mask token, a randomly injected token, or the
      original value.
    """
        validate_rates = control_flow_ops.Assert(
            self._mask_token_rate + self._random_token_rate <= 1,
            ["mask_token_rate + random_token_rate must be <= 1"])
        with ops.control_dependencies([validate_rates]):

            # Generate a random number for all mask-able items. Items that should be
            # treated atomically (e.g. all wordpieces in a token, span, etc) will have
            # the same random number.
            random_uniform = _get_random(masked_lm_ids)

            # Merge down to rank 2.
            random_uniform = (random_uniform if random_uniform.ragged_rank == 1
                              else random_uniform.merge_dims(1, -1))
            mask_values = masked_lm_ids

            all_mask_flat = array_ops.tile([self._mask_token],
                                           array_ops.shape(
                                               mask_values.flat_values))

            # Maybe insert the mask token (with probability `mask_token_rate`).
            should_mask_flat = random_uniform.flat_values < math_ops.cast(
                self._mask_token_rate, dtypes.float32)
            mask_values = mask_values.with_flat_values(
                ragged_where_op.where(should_mask_flat,
                                      x=math_ops.cast(
                                          all_mask_flat,
                                          mask_values.flat_values.dtype),
                                      y=mask_values.flat_values))

            # Maybe inject a random token (with probability `random_token_rate`).
            all_random_flat = random_ops.random_uniform(
                array_ops.shape(mask_values.flat_values),
                maxval=math_ops.cast(self._vocab_size, dtypes.float32))
            # A draw in (mask_token_rate, mask_token_rate + random_token_rate)
            # selects random injection.
            should_inject_random_flat = math_ops.logical_and(
                random_uniform.flat_values > self._mask_token_rate,
                random_uniform.flat_values <
                self._mask_token_rate + self._random_token_rate)
            mask_values = mask_values.with_flat_values(
                ragged_where_op.where(should_inject_random_flat,
                                      x=math_ops.cast(
                                          all_random_flat,
                                          mask_values.flat_values.dtype),
                                      y=mask_values.flat_values))
            return mask_values
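The three-way mask/random/keep split above can be reproduced with public TF 2.x ops, since `tf.where` dispatches to the ragged implementation when given `RaggedTensor` inputs. A minimal sketch with made-up data (`MASK_ID`, `VOCAB_SIZE`, and the BERT-style 0.8/0.1 rates are illustrative assumptions, not values from the class above):

```python
import tensorflow as tf

MASK_ID, VOCAB_SIZE = 103, 30000      # made-up vocabulary constants
mask_rate, random_rate = 0.8, 0.1     # BERT-style 80/10/10 split

ids = tf.ragged.constant([[2051, 2003, 2019], [7099]], dtype=tf.int64)

# One uniform draw per selected item, taken on the flat values.
u = tf.random.uniform(tf.shape(ids.flat_values))

flat = ids.flat_values
# Draws below mask_rate become the mask token.
flat = tf.where(u < mask_rate, tf.cast(MASK_ID, flat.dtype), flat)
# Draws in (mask_rate, mask_rate + random_rate) become a random vocab id.
random_ids = tf.random.uniform(tf.shape(flat), maxval=VOCAB_SIZE,
                               dtype=flat.dtype)
flat = tf.where((u > mask_rate) & (u < mask_rate + random_rate),
                random_ids, flat)

# Remaining items keep their original id.
masked = ids.with_flat_values(flat)
```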
Example #2
def _multivalent_span_alignment(overlaps):
    """Returns the multivalent span alignment for a given overlaps tensor.

  Args:
    overlaps: `<bool>[D1...DB, source_size, target_size]`: `overlaps[b1...bB,
      i, j]` is true if source span `i` overlaps target span `j` (in batch
      `b1...bB`).

  Returns:
    `<int64>[D1...DB, source_size, (num_aligned_spans)]`:
    `result[b1...bB, i, n]=j` if target span `j` is the `n`'th target span
    that aligns with source span `i` (in batch `b1...bB`).
  """
    overlaps_ndims = overlaps.shape.ndims
    assert overlaps_ndims is not None  # guaranteed/checked by span_overlaps()
    assert overlaps_ndims >= 2

    # If there are multiple batch dimensions, then flatten them and recurse.
    if overlaps_ndims > 3:
        if not isinstance(overlaps, ragged_tensor.RaggedTensor):
            overlaps = ragged_tensor.RaggedTensor.from_tensor(
                overlaps, ragged_rank=overlaps.shape.ndims - 3)
        return overlaps.with_values(
            _multivalent_span_alignment(overlaps.values))

    elif overlaps_ndims == 2:  # no batch dimension
        assert not isinstance(overlaps, ragged_tensor.RaggedTensor)
        overlap_positions = array_ops.where(overlaps)
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values=overlap_positions[:, 1],
            value_rowids=overlap_positions[:, 0],
            nrows=array_ops.shape(overlaps, out_type=dtypes.int64)[0])

    else:  # exactly one batch dimension
        if not isinstance(overlaps, ragged_tensor.RaggedTensor):
            overlaps = ragged_tensor.RaggedTensor.from_tensor(overlaps,
                                                              ragged_rank=1)
        overlap_positions = ragged_where_op.where(overlaps.values)
        if isinstance(overlaps.values, ragged_tensor.RaggedTensor):
            overlaps_values_nrows = overlaps.values.nrows()
        else:
            overlaps_values_nrows = array_ops.shape(overlaps.values,
                                                    out_type=dtypes.int64)[0]
        return overlaps.with_values(
            ragged_tensor.RaggedTensor.from_value_rowids(
                values=overlap_positions[:, 1],
                value_rowids=overlap_positions[:, 0],
                nrows=overlaps_values_nrows))
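For the rank-2 branch above (no batch dimension), the same construction can be sketched with the public API: `tf.where` in coordinate mode lists the `(i, j)` positions of true entries, and `from_value_rowids` groups the `j` values by source span. The data here is made up:

```python
import tensorflow as tf

# overlaps[i, j] = True iff source span i overlaps target span j.
overlaps = tf.constant([[True, True, False],
                        [False, False, False],
                        [False, True, True]])

positions = tf.where(overlaps)  # <int64>[num_true, 2] coordinates

alignment = tf.RaggedTensor.from_value_rowids(
    values=positions[:, 1],        # aligned target span j
    value_rowids=positions[:, 0],  # grouped by source span i
    nrows=tf.shape(overlaps, out_type=tf.int64)[0])
# alignment: [[0, 1], [], [1, 2]]
```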
Example #3
    def get_segments(self, sentences):
        """Extracts the next sentence label from sentences.

    Args:
      sentences: A `RaggedTensor` of strings w/ shape [batch, (num_sentences)].

    Returns:
      A tuple of (segment_a, segment_b, is_next_sentence) where:

      segment_a: A `Tensor` of strings w/ shape [total_num_sentences] that
        contains all the original sentences.
      segment_b: A `Tensor` of strings w/ shape [total_num_sentences] that
        contains either the subsequent sentence of `segment_a` or a randomly
        injected sentence.
      is_next_sentence: A `Tensor` of int64 w/ shape [total_num_sentences]
        indicating whether `segment_b` is truly the subsequent sentence of
        `segment_a` (1) or randomly injected (0).
    """
        next_sentence = ragged_map_ops.map_fn(
            functools.partial(manip_ops.roll, axis=0, shift=-1),
            sentences,
            dtype=ragged_tensor.RaggedTensorType(dtypes.string, 1),
            infer_shape=False)
        random_sentence = sentences.with_flat_values(
            self._shuffle_fn(sentences.flat_values))
        is_next_sentence_labels = (self._random_fn(sentences.flat_values.shape)
                                   > self._random_next_sentence_threshold)
        is_next_sentence = sentences.with_flat_values(is_next_sentence_labels)

        # Randomly decide if we should use next sentence or throw in a random
        # sentence.
        segment_two = ragged_where_op.where(is_next_sentence,
                                            x=next_sentence,
                                            y=random_sentence)

        # Get rid of the documents dimension.
        sentences = sentences.merge_dims(-2, -1)
        segment_two = segment_two.merge_dims(-2, -1)
        is_next_sentence = is_next_sentence.merge_dims(-2, -1)
        is_next_sentence = math_ops.cast(is_next_sentence, dtypes.int64)
        return sentences, segment_two, is_next_sentence
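The heart of `get_segments` is the `where` selection between the rolled (next) and shuffled (random) candidates. A dense sketch of just that step, with made-up sentences and an assumed stand-in value for `_random_next_sentence_threshold`:

```python
import tensorflow as tf

segment_a = tf.constant(["s1", "s2", "s3", "s4"])
next_sent = tf.constant(["s2", "s3", "s4", "s1"])    # rolled left by one
random_sent = tf.constant(["s3", "s1", "s2", "s4"])  # shuffled

threshold = 0.5  # assumed value for the random-next-sentence threshold
is_next = tf.random.uniform([4]) > threshold

# Keep the true next sentence where is_next holds, else inject a random one.
segment_b = tf.where(is_next, next_sent, random_sent)
labels = tf.cast(is_next, tf.int64)  # 1 = genuinely subsequent, 0 = random
```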
Example #4
def batch_gather_with_default(params, indices, default_value='', name=None):
    """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it substitutes
  `default_value` for any invalid (out-of-bounds) indices.
  See `batch_gather` for more details.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
    ```
  """
    with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params,
            name='params',
        )
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices,
            name='indices',
        )
        default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            default_value,
            name='default_value',
        )
        # TODO(hterry): lift this restriction and support default_values
        #               of rank > 1
        if default_value.shape.ndims not in (0, 1):
            raise ValueError('"default_value" must be a scalar or vector')
        upper_bounds = None
        if indices.shape.ndims is None:
            raise ValueError('Indices must have a known rank.')
        if params.shape.ndims is None:
            raise ValueError('Params must have a known rank.')

        num_batch_dimensions = indices.shape.ndims - 1
        pad = None
        # The logic for this works as follows:
        # - create a padded params, where:
        #    padded_params[b1...bn, 0] = default_value
        #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
        # - create an `upper_bounds` Tensor that contains the number of elements
        #   in each innermost row. Broadcast `upper_bounds` to the same shape
        #   as `indices`.
        # - check which entries of `indices` are out of bounds and substitute
        #   them with the index containing `default_value` (the first).
        # - call batch_gather with the adjusted indices.
        with ops.control_dependencies([
                check_ops.assert_greater_equal(array_ops.rank(params),
                                               array_ops.rank(indices))
        ]):
            if ragged_tensor.is_ragged(params):
                row_lengths = ragged_array_ops.expand_dims(
                    params.row_lengths(axis=num_batch_dimensions), axis=-1)
                upper_bounds = math_ops.cast(row_lengths, indices.dtype)

                pad_shape = _get_pad_shape(params, indices)

                pad = ragged_tensor_shape.broadcast_to(default_value,
                                                       pad_shape)
            else:
                params_shape = array_ops.shape(params)
                pad_shape = array_ops.concat([
                    params_shape[:num_batch_dimensions], [1],
                    params_shape[num_batch_dimensions + 1:params.shape.ndims]
                ], 0)
                upper_bounds = params_shape[num_batch_dimensions]
                pad = array_ops.broadcast_to(default_value, pad_shape)

            # Add `default_value` as the first value in the innermost (ragged) rank.
            pad = math_ops.cast(pad, params.dtype)
            padded_params = array_ops.concat([pad, params],
                                             axis=num_batch_dimensions)

            # Adjust the indices by substituting out-of-bound indices with the
            # default-value index (which is the first element).
            shifted_indices = indices + 1
            is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
            adjusted_indices = ragged_where_op.where(
                is_out_of_bounds,
                x=array_ops.zeros_like(indices),
                y=shifted_indices,
            )
            return array_ops.batch_gather(params=padded_params,
                                          indices=adjusted_indices,
                                          name=name)
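The pad-and-shift trick the comments describe can be seen in isolation in a dense-only sketch (public TF 2.x API; `tf.gather` with `batch_dims=1` stands in for `batch_gather`, and the data is made up):

```python
import tensorflow as tf

params = tf.constant([["a", "b", "c"], ["d", "e", "f"]])
indices = tf.constant([[1, 5], [-1, 0]])
default_value = tf.constant("FOO")

# Prepend the default so it lives at index 0 of every row.
pad = tf.fill([tf.shape(params)[0], 1], default_value)
padded = tf.concat([pad, params], axis=1)

# Shift valid indices up by one; send invalid ones to the default slot.
upper = tf.shape(params)[1]
bad = (indices < 0) | (indices >= upper)
adjusted = tf.where(bad, tf.zeros_like(indices), indices + 1)

result = tf.gather(padded, adjusted, batch_dims=1)
# result: [["b", "FOO"], ["FOO", "d"]]
```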
Example #5
 def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
   with self.assertRaisesRegex(error, message):
     ragged_where_op.where(condition, x, y)
Example #6
 def testRaggedWhere(self, condition, expected, x=None, y=None):
   result = ragged_where_op.where(condition, x, y)
   self.assertAllEqual(result, expected)
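Both test helpers above exercise the two calling modes of the op, which are also reachable through the public `tf.where` dispatch; a quick sketch with made-up inputs:

```python
import tensorflow as tf

condition = tf.ragged.constant([[True, False, True], [False]])

# Coordinate mode: condition only -- returns indices of True values.
tf.where(condition)        # [[0, 0], [0, 2]]

# Elementwise mode: condition, x and y -- selects between two raggeds.
x = tf.ragged.constant([[1, 2, 3], [4]])
y = tf.ragged.constant([[-1, -2, -3], [-4]])
tf.where(condition, x, y)  # [[1, -2, 3], [-4]]

# Passing x without y (or vice versa) raises ValueError, which is the
# behavior testRaggedWhereErrors asserts.
```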
Example #7
 def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
   with self.assertRaisesRegexp(error, message):
     ragged_where_op.where(condition, x, y)
Example #8
 def testRaggedWhere(self, condition, expected, x=None, y=None):
   result = ragged_where_op.where(condition, x, y)
   self.assertRaggedEqual(result, expected)