Example #1
    def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
        """Computes mean and variance over the valid data points in inputs."""
        inputs = py_utils.with_dependencies([
            py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
            py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
        ], inputs)
        rank = tf.rank(mask)
        reduce_over_dims = tf.range(0, rank - 1)
        sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                              reduce_over_dims)
        count_v = tf.reduce_sum(mask, reduce_over_dims)
        # Input shape is guaranteed to be a multiple of mask shape because the
        # inputs * mask op above was successfully broadcasted.
        mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
        count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_v = tf.tpu.cross_replica_sum(sum_v)
            count_v = tf.tpu.cross_replica_sum(count_v)

        count_v = tf.maximum(count_v, 1.0)
        mean = sum_v / count_v
        sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                               reduce_over_dims)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_vv = tf.tpu.cross_replica_sum(sum_vv)

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
        ], sum_vv / count_v)
        return mean, variance
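As a sanity check of the masked-statistics idea in this example, here is a minimal standalone sketch in plain TensorFlow 2 (toy shapes and values made up for illustration; lingvo's py_utils is not used):

import tensorflow as tf

# Toy batch: 2 sequences, 3 steps, feature dim 2. The last step of the second
# sequence is padding (mask 0), so it must not affect the statistics.
inputs = tf.constant([[[1., 10.], [2., 20.], [3., 30.]],
                      [[4., 40.], [5., 50.], [9., 99.]]])
mask = tf.constant([[[1.], [1.], [1.]],
                    [[1.], [1.], [0.]]])  # [batch, time, 1]

reduce_dims = [0, 1]  # every dim except the last (feature) dim, as above
count = tf.maximum(tf.reduce_sum(mask, reduce_dims), 1.0)  # 5 valid frames
mean = tf.reduce_sum(inputs * mask, reduce_dims) / count
variance = tf.reduce_sum(tf.square(inputs - mean) * mask, reduce_dims) / count

print(mean.numpy())      # ~[3., 30.]: the padded frame [9., 99.] is ignored
print(variance.numpy())  # ~[2., 200.]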
Example #2
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
    """Computes mean and variance over the valid data points in inputs."""
    mask = 1.0 - padding
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                          reduce_over_dims,
                          keepdims=keepdims)
    count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)

    if cumulative_axis is not None:
        sum_v = tf.math.cumsum(sum_v, axis=cumulative_axis)
        count_v = tf.math.cumsum(count_v, axis=cumulative_axis)
    # Input shape is guaranteed to be a multiple of mask shape because the
    # inputs * mask op above was successfully broadcasted.
    input_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(inputs), reduce_over_dims))
    mask_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(mask), reduce_over_dims))
    mask_multiplier = tf.math.truediv(input_size_on_reduced_dims,
                                      mask_size_on_reduced_dims)
    count_v *= tf.cast(mask_multiplier, count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_v = tf.tpu.cross_replica_sum(sum_v)
        count_v = tf.tpu.cross_replica_sum(count_v)

    count_v = tf.maximum(count_v, 1.0)
    mean = sum_v / count_v
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                           reduce_over_dims,
                           keepdims=keepdims)
    if cumulative_axis is not None:
        sum_vv = tf.math.cumsum(sum_vv, axis=cumulative_axis)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
Example #3
 def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean,
                norm_variance):
     p = self.params
     with tf.control_dependencies([
             py_utils.assert_greater_equal(norm_variance,
                                           tf.zeros_like(norm_variance)),
             py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                         tf.shape(norm_mean)),
             py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                         tf.shape(norm_variance)),
     ]):
         if p.use_fused_batch_norm_for_eval and (self.do_eval
                                                 or p.freeze_bn_stats):
             bn_output, _, _ = nn.fused_batch_norm(inputs,
                                                   gamma,
                                                   beta,
                                                   norm_mean,
                                                   norm_variance,
                                                   self._epsilon,
                                                   is_training=False)
         else:
             bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                   norm_variance, beta,
                                                   gamma, self._epsilon)
         if p.set_padded_output_to_zero:
             bn_output = py_utils.ApplyPadding(paddings, bn_output)
     return bn_output
Example #4
    def FProp(self, theta, inputs, paddings, class_emb):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [batch, ..., dim].
      paddings: The paddings tensor.  Shaped [batch, ..., 1], with the same rank
        as the input tensor.
      class_emb: The conditioning inputs. Shaped [batch, emb_dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        p = self.params
        batch = py_utils.GetShape(inputs)[0]
        class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])
        if not py_utils.use_tpu():
            class_emb = py_utils.with_dependencies([
                py_utils.assert_less_equal(
                    tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
                py_utils.assert_greater_equal(
                    tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
                py_utils.assert_equal(tf.ones([batch], tf.int32),
                                      tf.cast(tf.reduce_sum(class_emb, -1),
                                              tf.int32),
                                      name='one_hot_assert3'),
            ], class_emb)

        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings=paddings, class_emb=class_emb)
            return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                                   norm_variance)
Example #5
    def FProp(self, theta, inputs, paddings=None):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings)
            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.zeros_like(norm_variance)),
                    py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                                tf.shape(norm_variance)),
            ]):
                bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                      norm_variance, beta,
                                                      gamma, self._epsilon)
            bn_output *= 1.0 - paddings
            return bn_output
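For reference, tf.nn.batch_normalization with externally supplied moments, as used above, is just the elementwise formula gamma * (x - mean) / sqrt(variance + epsilon) + beta; a small hedged check with made-up numbers:

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
mean = tf.constant([2.0, 3.0])
variance = tf.constant([1.0, 4.0])
beta = tf.constant([0.0, 0.5])    # offset
gamma = tf.constant([1.0, 2.0])   # scale
epsilon = 1e-3

bn = tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
manual = gamma * (x - mean) * tf.math.rsqrt(variance + epsilon) + beta
print(tf.reduce_max(tf.abs(bn - manual)).numpy())  # ~0.0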
Example #6
    def _Normalize(self, theta, grouped_inputs, group_mean, group_variance):
        p = self.params
        group_mean = py_utils.CheckNumerics(
            group_mean, f'mean of {p.name} failed numeric check.')
        group_variance = py_utils.CheckNumerics(
            group_variance, f'variance of {p.name} failed numeric check.')

        input_shape = py_utils.GetShape(grouped_inputs)
        moment_shape = list(input_shape)
        if p.input_rank == 4:
            moment_shape[2] = 1
            moment_shape[-1] = 1
        else:
            moment_shape[-1] = 1
        if not p.cumulative:
            # If not cumulative, the seqlen dimension is also reduced.
            moment_shape[1] = 1

        group_mean = py_utils.HasShape(group_mean, moment_shape)
        group_variance = py_utils.HasShape(group_variance, moment_shape)
        group_variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(group_variance,
                                          tf.cast(0, group_variance.dtype))
        ], group_variance)

        grouped_inputs = (grouped_inputs - group_mean
                          ) * tf.math.rsqrt(group_variance + self._epsilon)
        # Merges the last two dims.
        grouped_inputs = tf.reshape(grouped_inputs, input_shape[:-2] + [-1])

        # Note: the real gamma to use is 1 + gamma.
        outputs = grouped_inputs * (theta.gamma + 1) + theta.beta
        return outputs
Example #7
    def FProp(self, theta, inputs, paddings=None):
        """Apply group normalization.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor with shape [batch_size, height, width, channel].
      paddings: The paddings tensor with shape [batch_size, height]. Intended to
        be used for sequence processing where `height` is `time`.

    Returns:
      A single tensor as the output after applying group normalization, with
      the same shape as 'inputs', or an (output, output_paddings) pair if
      input paddings is not None.
    """
        p = self.params
        n, h, w, c = tf.unstack(tf.shape(inputs), axis=0, num=4)
        group_size = p.dim // p.num_groups
        num_groups = p.num_groups
        min_group_size = p.min_group_size if p.dim > p.min_group_size else p.dim
        if group_size <= min_group_size:
            group_size = min_group_size
            num_groups = p.dim // group_size

        with tf.name_scope(p.name):
            x = tf.reshape(inputs, [n, h, w, num_groups, group_size])
            if paddings is None:
                counts, means_ss, variance_ss, _ = tf.nn.sufficient_statistics(
                    x, axes=[1, 2, 4], keepdims=True)
                norm_mean, norm_variance = tf.nn.normalize_moments(
                    counts, means_ss, variance_ss, None)
            else:
                expanded_paddings = tf.reshape(paddings, [n, h, 1, 1, 1])
                norm_mean, norm_variance = ComputeMomentsWithPadding(
                    x, expanded_paddings, [1, 2, 4], keepdims=True)

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            beta = theta.beta
            gamma = theta.gamma

            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.cast(0., norm_variance.dtype)),
                    py_utils.assert_shape_match([n, 1, 1, num_groups, 1],
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match([n, 1, 1, num_groups, 1],
                                                tf.shape(norm_variance)),
            ]):
                x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
                x = tf.reshape(x, [n, h, w, c])
                gn_output = x * gamma + beta
                gn_output = tf.reshape(gn_output, [n, h, w, c])
                if paddings is None:
                    return gn_output
                else:
                    return gn_output, paddings
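A stripped-down sketch of just the grouping step in this example: the channel dimension is reshaped into [num_groups, group_size] and moments are taken per (batch, group). Toy sizes, with tf.nn.moments standing in for the padding-aware helper:

import tensorflow as tf

n, h, w, c, num_groups = 2, 3, 3, 8, 4
group_size = c // num_groups

x = tf.random.normal([n, h, w, c])
grouped = tf.reshape(x, [n, h, w, num_groups, group_size])

# Mean/variance over height, width and within-group channels: axes [1, 2, 4].
mean, variance = tf.nn.moments(grouped, axes=[1, 2, 4], keepdims=True)
print(mean.shape, variance.shape)  # (2, 1, 1, 4, 1) each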
Example #8
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` evenly into num_splits along the 1st dimenion.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A tuple of lists of tensors; the tuple has len(xs) elements.

    i-th element in each list corresponds to i-th split of each tensor in xs
    along the first dimension of each tensor.
  """
  # assert first dim of all tensors in xs is equal
  batch_dims = [tf.shape(x)[0] for x in xs]
  all_batch_dims = tf.stack(batch_dims)

  all_batch_dims = py_utils.with_dependencies([
      py_utils.assert_equal(
          all_batch_dims,
          tf.shape(xs[0])[0],
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          tf.shape(xs[0])[0],
          num_splits,
          message='first dim of tensors in xs must be at least num_splits')
  ], all_batch_dims)

  splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
  # add the above assertion into the compute graph
  splits = py_utils.with_dependencies([all_batch_dims], splits)
  split_xs = [tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs]

  return split_xs
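ComputeSplits itself is not part of this excerpt; below is a hedged sketch of what a near-even split along the first dimension can look like (the size computation is an illustrative stand-in, not lingvo's implementation):

import tensorflow as tf

def NearlyEvenSplitSizes(batch_size, num_splits):
    base, remainder = divmod(batch_size, num_splits)
    # The first `remainder` splits get one extra element.
    return [base + 1] * remainder + [base] * (num_splits - remainder)

x = tf.reshape(tf.range(20), [10, 2])           # first dim is the batch dim
sizes = NearlyEvenSplitSizes(10, 3)             # [4, 3, 3]
splits = tf.split(x, num_or_size_splits=sizes, axis=0)
print([s.shape[0] for s in splits])             # [4, 3, 3]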
Example #9
def MakeCausalPadding(seq_len,
                      block_size,
                      left_context,
                      right_context,
                      dtype=tf.float32):
    """Makes the causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.
    dtype: tf.dtype, default is tf.float32.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in {0, 1},
    where context_size = block_size + (left_context - 1) + right_context.
    Element (b, i, j) is zero if, in the b-th block, the i-th frame can access
    the j-th frame in the context.
  """
    seq_len = py_utils.with_dependencies([
        py_utils.assert_greater_equal(
            seq_len, 1, message='seq_len must be at least 1')
    ], seq_len)

    num_blocks = (seq_len + block_size - 1) // block_size
    context_size = block_size + (left_context - 1) + right_context

    # [num_blocks, block_size]: source positions in the original sequence.
    src_positions = tf.reshape(tf.range(num_blocks * block_size),
                               [num_blocks, block_size])
    # [num_blocks,]: source positions at the start of each block.
    block_start_positions = tf.range(0, num_blocks * block_size, block_size)
    # [context_size]:  positions relative to the block start.
    relative_context_positions = tf.range(context_size) - (left_context - 1)

    # [num_blocks, context_size]: target positions in the original sequence.
    tgt_positions = (block_start_positions[:, tf.newaxis] +
                     relative_context_positions[tf.newaxis, :])
    # [num_blocks, block_size, context_size]: position differences between source-
    # target pairs.
    position_diff = src_positions[:, :,
                                  tf.newaxis] - tgt_positions[:, tf.newaxis, :]
    # [num_blocks, block_size, context_size]: if attention is allowed between
    # source-target pairs.
    valid_atten = tf.math.logical_and(-right_context <= position_diff,
                                      position_diff < left_context)

    # [num_blocks, block_size]: if the source position is valid, not padded.
    valid_src = src_positions < seq_len
    # [num_blocks, context_size]: if the target position is valid, not padded.
    valid_tgt = tf.math.logical_and(0 <= tgt_positions,
                                    tgt_positions < seq_len)

    valid_atten &= tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                       valid_tgt[:, tf.newaxis, :])

    padding = 1.0 - tf.cast(valid_atten, dtype=dtype)

    return padding
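The geometry is easiest to see on a tiny case. Here is a NumPy sketch of the same position-difference test, with seq_len=4, block_size=2, left_context=3, right_context=0 (these numbers mirror the worked example in the MakeLocalMask docstring further below; the padding is 1 minus that mask):

import numpy as np

seq_len, block_size, left_context, right_context = 4, 2, 3, 0
num_blocks = (seq_len + block_size - 1) // block_size            # 2
context_size = block_size + (left_context - 1) + right_context   # 4

src = np.arange(num_blocks * block_size).reshape(num_blocks, block_size)
block_start = np.arange(0, num_blocks * block_size, block_size)
tgt = block_start[:, None] + (np.arange(context_size) - (left_context - 1))

# Attention from source i to target j is allowed iff
# -right_context <= (i - j) < left_context, and both positions are in range.
diff = src[:, :, None] - tgt[:, None, :]
valid = (-right_context <= diff) & (diff < left_context)
valid &= (src[:, :, None] < seq_len)
valid &= (0 <= tgt[:, None, :]) & (tgt[:, None, :] < seq_len)

padding = 1.0 - valid.astype(np.float32)  # [num_blocks, block_size, context_size]
print(padding)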
Example #10
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    tensor = tf.tile(tensor,
                                     [num_hyps_per_beam, 1
                                      ])  # [num_hyps_per_beam, src_batch]
                    tgt_batch = tf.shape(step_ids)[
                        0]  # num_hyps_per_beam*src_batch
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                uncertainty = tf.constant(
                    1e-10,
                    p.dtype)  # avoid 0 probs which may cause issues with log
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
Example #11
    def _MaybeExpandPaddings(self, inputs, paddings):
        # rank difference is at most one.
        rank_diff = tf.rank(inputs) - tf.rank(paddings)
        paddings = py_utils.with_dependencies([
            py_utils.assert_less_equal(rank_diff, 1),
            py_utils.assert_greater_equal(rank_diff, 0)
        ], paddings)

        # Pads [1] to the end of paddings.
        paddings = tf.reshape(
            paddings,
            tf.concat([tf.shape(paddings),
                       tf.tile([1], [rank_diff])], axis=0))
        return paddings
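A minimal standalone run of the same rank-expansion trick (eager TensorFlow 2, toy shapes): a trailing dimension of size 1 is appended to the paddings only when they are one rank lower than the inputs.

import tensorflow as tf

inputs = tf.zeros([2, 5, 8])   # rank 3
paddings = tf.zeros([2, 5])    # rank 2, so one trailing dim is missing

rank_diff = tf.rank(inputs) - tf.rank(paddings)  # 1 (or 0 for equal ranks)
new_shape = tf.concat(
    [tf.shape(paddings), tf.tile([1], [rank_diff])], axis=0)
expanded = tf.reshape(paddings, new_shape)
print(expanded.shape)  # (2, 5, 1); unchanged when rank_diff is 0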
Example #12
                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is overridden
                    # later
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids,
                                                                  1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    label_probs = tf.one_hot(label,
                                             vocab_size,
                                             dtype=py_utils.FPropDtype(
                                                 p))  # [tgt_batch, vocab_size]
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent
Example #13
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStateCallback.

      Biases results.log_probs towards provided encoder_outputs.targets.

      Args:
        theta: a NestedMap of parameters.
        encoder_outputs: a NestedMap computed by encoder.
        step_ids: A tensor of shape [tgt_batch, 1].
        states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
        num_hyps_per_beam: Beam size.
        *args: additional arguments to _PreBeamSearchStepCallback.
        **kwargs: additional arguments to _PreBeamSearchStepCallback.

      Returns:
        A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs:
            The updated attention probs, of shape [tgt_batch, src_len].
          log_probs:
            Log prob for each of the tokens in the target vocab. This is of
            shape [tgt_batch, vocab_size].
        out_states: a `.NestedMap` The updated states. The states relevant here
          are:
          time_step: A scalar indicating current step of decoder.  Must be
            provided and maintained by subclass.
          consistent: A boolean vector of shape [tgt_batch, ] which tracks
              whether each hypothesis has exactly matched
              encoder_outputs.targets so far.
      """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def TileForBeamAndFlatten(tensor):
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                tensor = tf.tile(
                    tensor,
                    [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
                tgt_batch = tf.shape(step_ids)[
                    0]  # num_hyps_per_beam*src_batch
                return tf.reshape(tensor, [tgt_batch])

            # Consistent if step_ids == labels from previous step
            # TODO(navari): Consider updating consistent only if weights > 0. Then
            # re-evaluate the need for bias_only_if_consistent=True.
            # Note that prev_label is incorrect for step 0 but is overridden later
            prev_label = TileForBeamAndFlatten(
                tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
            is_step0 = tf.equal(time_step, 0)
            local_consistence = tf.logical_or(
                is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
            out_states.consistent = tf.logical_and(states.consistent,
                                                   local_consistence)

            # get label, weight slices corresponding to current time_step
            label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
            weight = TileForBeamAndFlatten(
                tf.gather(weights, time_step, axis=1))
            if p.bias_only_if_consistent:
                weight = weight * tf.cast(out_states.consistent, p.dtype)

            # convert from dense label to sparse label probs
            vocab_size = tf.shape(bs_results.log_probs)[1]
            uncertainty = tf.constant(
                1e-10,
                p.dtype)  # avoid 0 probs which may cause issues with log
            label_probs = tf.one_hot(label,
                                     vocab_size,
                                     on_value=1 - uncertainty,
                                     off_value=uncertainty /
                                     tf.cast(vocab_size - 1, p.dtype),
                                     dtype=p.dtype)  # [tgt_batch, vocab_size]
            pred_probs = tf.exp(bs_results.log_probs)

            # interpolate predicted probs and label probs
            weight = tf.expand_dims(weight, 1)
            probs = py_utils.with_dependencies([
                py_utils.assert_less_equal(weight, 1.),
                py_utils.assert_greater_equal(weight, 0.)
            ], (1.0 - weight) * pred_probs + weight * label_probs)

            bs_results.log_probs = tf.log(probs)

            return bs_results, out_states
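The heart of this callback is the convex interpolation between the decoder's predicted distribution and a near-one-hot label distribution; a tiny hedged sketch with made-up numbers (weight 0 keeps the prediction, weight 1 replaces it with the smoothed label):

import tensorflow as tf

vocab_size = 4
pred_probs = tf.constant([[0.1, 0.2, 0.3, 0.4]])
label = tf.constant([2])
uncertainty = 1e-10  # avoids exact zeros before the log
label_probs = tf.one_hot(label, vocab_size,
                         on_value=1.0 - uncertainty,
                         off_value=uncertainty / (vocab_size - 1))

for w in (0.0, 0.5, 1.0):
    probs = (1.0 - w) * pred_probs + w * label_probs
    print(w, probs.numpy())  # each row still sums to 1
    log_probs = tf.math.log(tf.maximum(probs, 1e-12))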
Example #14
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
    """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding it to the selected set.
  2. For all remaining points, find the furthest point from those selected.
  3. Add the furthest point to the selected set.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point is
      real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      pairwise distances between points. If None, distances are computed on the
      fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. This will have a range of [0, P1).

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part of
    the pooling operation: each point is assigned to the closest sampled point
    and a max is taken over them. This will have a range of [0, P2) with the
    index of the closest sampled point that remains.
  """
    points = py_utils.HasRank(points, 3)
    batch_size, num_points, dims = py_utils.GetShape(points, 3)

    points = py_utils.with_dependencies(
        [py_utils.assert_greater_equal(num_points, num_sampled_points)],
        points)

    # Add a tiny bit of noise to the distance matrix or points so all
    # points are unique. This will also ensure true repeated points
    # like padded points are only selected after all valid points are selected.
    if precomputed_squared_distance is not None:
        precomputed_squared_distance = py_utils.HasShape(
            precomputed_squared_distance, [batch_size, num_points, num_points])
        precomputed_squared_distance += tf.random.uniform(
            (batch_size, num_points, 1),
            minval=1e-6,
            maxval=1e-5,
            dtype=tf.float32,
            seed=random_seed)
    else:
        points += tf.random.uniform((batch_size, num_points, dims),
                                    minval=1e-6,
                                    maxval=1e-5,
                                    dtype=tf.float32,
                                    seed=random_seed)

    # TensorArray to store the sampled indices in the loop.
    sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

    # Initialize distance_to_selected to inf for all points.
    distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

    # For tracking the index to the closest selected point.
    closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

    # Current loop index counter.
    curr_idx = tf.constant(0, dtype=tf.int32)

    # Get number of valid points (1 is padded, so num_points - num_padded).
    num_valid_points = tf.cast(tf.cast(num_points, dtype=tf.float32) -
                               tf.reduce_sum(padding, axis=1),
                               dtype=tf.int32)

    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""
        def _GetRandomRealPoint():
            """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create a random value per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the index to the closest selected point.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

    _, _, sampled_idx, closest_idx = tf.while_loop(
        lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
        _BodyFn,
        loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
        back_prop=False,
        maximum_iterations=num_sampled_points)

    sampled_idx = sampled_idx.stack()  # num_sampled_points x n
    sampled_idx = tf.transpose(sampled_idx, [1, 0])

    if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
        sampled_idx.set_shape((batch_size, num_sampled_points))

    return sampled_idx, closest_idx
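For readers who want the sampler without the padding, seeding, and TensorArray machinery, here is a hedged single-example NumPy sketch of the same greedy loop (names are illustrative, not lingvo's):

import numpy as np

def FarthestPointSampleSimple(points, num_sampled, seed=0):
    """Greedy FPS on a single [P, dims] point set; no padding handling."""
    rng = np.random.default_rng(seed)
    num_points = points.shape[0]
    dist_to_selected = np.full(num_points, np.inf)
    sampled = np.empty(num_sampled, dtype=np.int64)
    sampled[0] = rng.integers(num_points)         # 1. random first point
    for i in range(1, num_sampled):
        # 2. distance of every point to the selected set (min over selected),
        #    updated incrementally with the most recently added point.
        d = np.sum((points - points[sampled[i - 1]]) ** 2, axis=1)
        dist_to_selected = np.minimum(dist_to_selected, d)
        sampled[i] = np.argmax(dist_to_selected)  # 3. take the farthest point
    return sampled

pts = np.random.default_rng(1).normal(size=(100, 3))
print(FarthestPointSampleSimple(pts, 5))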
Example #15
    def FProp(self, theta, inputs, paddings=None):
        """Apply group normalization.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor with shape [batch_size, height, width, channel].
      paddings: The paddings tensor with shape [batch_size, height]. Intended to
        be used for sequence processing where `height` is `time`.

    Returns:
      A single tensor as the output after applying group normalization, with
      the same shape as 'inputs', or an (output, output_paddings) pair if
      input paddings is not None.
    """
        p = self.params
        inputs = py_utils.with_dependencies([
            py_utils.assert_greater_equal(py_utils.GetRank(inputs),
                                          p.input_rank)
        ], inputs)

        min_group_size = min(p.min_group_size, p.dim)
        group_size = max(p.dim // p.num_groups, min_group_size)
        num_groups = p.dim // group_size

        input_shape = py_utils.GetShape(inputs)
        with tf.name_scope(p.name):
            x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size])
            expanded_rank = p.input_rank + 1
            all_dims = list(range(expanded_rank))
            if paddings is None:
                # Skip d0, d[-2]
                axes = all_dims[1:-2] + all_dims[-1:]
                counts, means_ss, variance_ss, _ = tf.nn.sufficient_statistics(
                    x, axes=axes, keepdims=True)
                norm_mean, norm_variance = tf.nn.normalize_moments(
                    counts, means_ss, variance_ss, None)
            else:
                expanded_paddings = tf.reshape(
                    paddings, input_shape[:2] + [1] * (expanded_rank - 2))
                # skip the batching and group dim
                if p.cumulative:
                    # Skip d0, d1 and d[-2]
                    reduce_over_dims = all_dims[2:-2] + all_dims[-1:]
                    norm_mean, norm_variance = ComputeMomentsWithPadding(
                        x,
                        expanded_paddings,
                        reduce_over_dims=reduce_over_dims,
                        cumulative_axis=1,
                        keepdims=True)
                else:
                    # Skip d0, d[-2]
                    reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
                    norm_mean, norm_variance = ComputeMomentsWithPadding(
                        x, expanded_paddings, reduce_over_dims, keepdims=True)

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            beta = theta.beta
            gamma = theta.gamma
            n = input_shape[0]
            t = input_shape[1] if p.cumulative else 1
            norm_shape = [n, t, 1, num_groups, 1
                          ] if p.input_rank == 4 else [n, t, num_groups, 1]
            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.cast(0., norm_variance.dtype)),
                    py_utils.assert_shape_match(norm_shape,
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match(norm_shape,
                                                tf.shape(norm_variance)),
            ]):
                x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
                x = tf.reshape(x, input_shape)
                gn_output = x * gamma + beta
                gn_output = tf.reshape(gn_output, input_shape)
                if paddings is None:
                    return gn_output
                else:
                    return gn_output, paddings
Example #16
    def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                       cached_var):
        """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
      new_cached_var: same shape as cached_var.
    """
        tf.logging.vlog(1, 'inputs: %r', inputs)
        tf.logging.vlog(1, 'paddings: %r', paddings)
        tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
        tf.logging.vlog(1, 'cached_count: %r', cached_count)

        inputs = py_utils.ApplyPadding(paddings, inputs, use_select=False)

        input_rank = py_utils.GetRank(inputs)
        assert input_rank is not None, (f'inputs rank must be static for '
                                        f'{repr(inputs)}')
        reduce_over_dims = list(range(input_rank))
        # Skip B, T, and N. Reduce {F,G} or just G.
        reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
        tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

        # [B, T, 1, N, 1] or [B, T, N, 1]
        sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
        sum_v = tf.math.cumsum(sum_v, axis=1)
        sum_v += cached_sum

        # [B, T, 1, 1, 1] or [B, T, 1, 1]
        mask = tf.cast(1.0 - paddings, inputs.dtype)
        count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
        count_v = tf.math.cumsum(count_v, axis=1)
        input_shape = py_utils.GetShape(inputs)
        if input_rank == 4:
            # F * G
            multiplier = input_shape[-1] * input_shape[-3]
        else:
            # G
            multiplier = input_shape[-1]
        count_v *= multiplier
        count_v += cached_count

        tf.logging.vlog(1, 'sum_v: %r', sum_v)
        tf.logging.vlog(1, 'count_v: %r', count_v)

        mean = sum_v / tf.maximum(count_v, 1.0)

        sum_vv = tf.reduce_sum(py_utils.ApplyPadding(
            paddings,
            tf.math.squared_difference(inputs, mean),
            use_select=False),
                               reduce_over_dims,
                               keepdims=True)
        sum_vv = tf.math.cumsum(sum_vv, axis=1)
        sum_vv += cached_var

        cached_sum = sum_v[:, -1:]
        cached_count = count_v[:, -1:]
        cached_var = sum_vv[:, -1:]

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
        ], sum_vv / tf.maximum(count_v, 1.0))
        return mean, variance, cached_sum, cached_count, cached_var
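The cumulative-sum bookkeeping is easier to see in isolation. Below is a hedged toy sketch (mean only, no padding or groups) showing that carrying the last cumulative sum and count forward as a cache makes chunked processing agree with processing the whole sequence at once:

import tensorflow as tf

def ChunkRunningMean(chunk, cached_sum, cached_count):
    """chunk: [B, T, G]. Returns per-step running mean and the new cache."""
    g = tf.cast(tf.shape(chunk)[-1], chunk.dtype)
    sum_v = tf.math.cumsum(
        tf.reduce_sum(chunk, axis=-1, keepdims=True), axis=1) + cached_sum
    count_v = tf.math.cumsum(tf.ones_like(sum_v) * g, axis=1) + cached_count
    return sum_v / count_v, sum_v[:, -1:], count_v[:, -1:]

x = tf.random.normal([1, 6, 4])
mean1, s, c = ChunkRunningMean(x[:, :3], 0.0, 0.0)
mean2, _, _ = ChunkRunningMean(x[:, 3:], s, c)
mean_full, _, _ = ChunkRunningMean(x, 0.0, 0.0)
diff = tf.concat([mean1, mean2], axis=1) - mean_full
print(tf.reduce_max(tf.abs(diff)).numpy())  # ~0.0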
Example #17
def MakeLocalMask(seq_len,
                  block_size,
                  left_context,
                  right_context,
                  query_stride=1,
                  dtype=tf.float32):
    """Makes the mask tensor for a full sequence.

  The returned mask reflects the given context sizes, where position i
  attends to tokens in the range [i - (left_context-1), i + right_context].

  For example, given seq_len=4, block_size=2, left_context=3, right_context=0,
  the result mask is
  [[[0., 0., 1., 0.],  1st query in 1st block attends 1st key.
    [0., 0., 1., 1.]], 2nd query in 1st block attends 2nd and left keys.
   [[1., 1., 1., 0.],  1st query in 2nd block attends 1st and left keys.
    [0., 1., 1., 1.]]] 2nd query in 2nd block attends 2nd and left keys.

  The position i can move by stride, which means queries are pooled by stride.
  For example, given same params and stride=2, the result mask is
  [[[0., 0., 1., 1.]], The pooled query in 1st block attends 1st and 2nd keys.
   [[1., 1., 1., 1.]]] The pooled query in 2nd block attends 1st, 2nd and left keys.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.
    query_stride: int. Query stride for funnel pool.
    dtype: tf.dtype, default is tf.float32.

  Returns:
    A tensor of [num_blocks, block_size//stride, context_size] taking values in
    {0, 1}, where context_size = block_size + (left_context - 1) + right_context
    Element (b, i, j) is 1 if, in the b-th block, the i-th frame can access
    the j-th frame in the context.
  """
    assert block_size % query_stride == 0, (
        f'block_size({block_size}) must be a multiple of '
        f'query_stride({query_stride}).')
    seq_len = py_utils.with_dependencies([
        py_utils.assert_greater_equal(
            seq_len, 1, message='seq_len must be at least 1')
    ], seq_len)

    num_blocks = (seq_len + block_size - 1) // block_size
    context_size = block_size + (left_context - 1) + right_context

    # [num_blocks, block_size]: source positions in the original sequence.
    src_positions = tf.reshape(tf.range(num_blocks * block_size),
                               [num_blocks, block_size])
    # [num_blocks,]: source positions at the start of each block.
    block_start_positions = tf.range(0, num_blocks * block_size, block_size)
    # [context_size]:  positions relative to the block start.
    relative_context_positions = tf.range(context_size) - (left_context - 1)

    # [num_blocks, context_size]: target positions in the original sequence.
    tgt_positions = (block_start_positions[:, tf.newaxis] +
                     relative_context_positions[tf.newaxis, :])
    # [num_blocks, block_size, context_size]: position differences between source-
    # target pairs.
    position_diff = src_positions[:, :,
                                  tf.newaxis] - tgt_positions[:, tf.newaxis, :]
    # [num_blocks, block_size, context_size]: if attention is allowed between
    # source-target pairs.
    valid_atten = tf.math.logical_and(-right_context <= position_diff,
                                      position_diff < left_context)

    # [num_blocks, block_size]: if the source position is valid, not padded.
    valid_src = src_positions < seq_len
    # [num_blocks, context_size]: if the target position is valid, not padded.
    valid_tgt = tf.math.logical_and(0 <= tgt_positions,
                                    tgt_positions < seq_len)

    valid_atten &= tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                       valid_tgt[:, tf.newaxis, :])
    valid_atten = tf.cast(valid_atten, dtype=dtype)

    if query_stride:
        valid_atten = tf.reshape(valid_atten, [
            num_blocks, block_size // query_stride, query_stride, context_size
        ])
        valid_atten = tf.reduce_max(valid_atten, axis=-2)

    return valid_atten
Example #18
    def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                       cached_var):
        """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
      new_cached_var: same shape as cached_var.
    """
        tf.logging.vlog(1, 'inputs: %r', inputs)
        tf.logging.vlog(1, 'paddings: %r', paddings)
        tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
        tf.logging.vlog(1, 'cached_count: %r', cached_count)

        mask = tf.cast(1.0 - paddings, inputs.dtype)
        inputs *= tf.cast(mask, inputs.dtype)

        input_rank = py_utils.GetRank(inputs)
        assert input_rank is not None, (f'inputs rank must be static for '
                                        f'{repr(inputs)}')
        reduce_over_dims = list(range(input_rank))
        # Skip B, T, and N. Reduce {F,G} or just G.
        reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
        tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

        # [B, T, 1, N, 1] or [B, T, N, 1]
        sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
        sum_v = tf.math.cumsum(sum_v, axis=1)
        sum_v += cached_sum

        # [B, T, 1, 1, 1] or [B, T, 1, 1]
        count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
        count_v = tf.math.cumsum(count_v, axis=1)
        input_shape = py_utils.GetShape(inputs)
        if input_rank == 4:
            # F * G
            multiplier = input_shape[-1] * input_shape[-3]
        else:
            # G
            multiplier = input_shape[-1]
        count_v *= multiplier
        count_v += cached_count
        count_v = tf.maximum(count_v, 1.0)

        tf.logging.vlog(1, 'sum_v: %r', sum_v)
        tf.logging.vlog(1, 'count_v: %r', count_v)

        mean = sum_v / count_v
        if py_utils.FLAGS.tflite_compatible:
            # TfLite doesn't support broadcasting with 5D tensors.
            inputs_shape = py_utils.GetShape(inputs)
            if len(inputs_shape) == 4:
                tiled_mean = tf.tile(mean, [1, 1, 1, inputs_shape[3]])
            else:
                tiled_mean = tf.tile(
                    mean, [1, 1, inputs_shape[2], 1, inputs_shape[4]])
            sum_vv = tf.reduce_sum(tf.math.square(inputs - tiled_mean) * mask,
                                   reduce_over_dims,
                                   keepdims=True)
        else:
            sum_vv = tf.reduce_sum((inputs - mean)**2 * mask,
                                   reduce_over_dims,
                                   keepdims=True)
        sum_vv = tf.math.cumsum(sum_vv, axis=1)
        sum_vv += cached_var

        cached_sum = sum_v[:, -1:]
        cached_count = count_v[:, -1:]
        cached_var = sum_vv[:, -1:]

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
        ], sum_vv / count_v)
        return mean, variance, cached_sum, cached_count, cached_var