Example #1
    def StreamStep(self, theta, inputs, paddings, state0):
        if py_utils.testonly_skip_norm_layers():
            return inputs, paddings, state0

        p = self.params
        assert p.cumulative
        inputs = py_utils.HasRank(inputs, p.input_rank)

        group_size = self.group_size
        num_groups = self.num_groups
        tf.logging.vlog(1, 'group_size: %s', group_size)
        tf.logging.vlog(1, 'num_groups: %s', num_groups)

        input_shape = py_utils.GetShape(inputs)
        with tf.name_scope(f'{p.name}/StreamStep'):
            x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size])
            expanded_rank = p.input_rank + 1
            expanded_paddings = tf.reshape(
                paddings, input_shape[:2] + [1] * (expanded_rank - 2))
            (group_mean, group_variance, cached_sum, cached_count,
             cached_var) = self._StreamMoments(x, expanded_paddings,
                                               state0.cached_sum,
                                               state0.cached_count,
                                               state0.cached_var)
            outputs = self._Normalize(theta, x, group_mean, group_variance)

            return outputs, paddings, py_utils.NestedMap(
                cached_sum=cached_sum,
                cached_count=cached_count,
                cached_var=cached_var)
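
The StreamStep above normalizes one chunk of a stream at a time: the running totals needed for cumulative group moments travel between calls in state0 (cached_sum, cached_count, cached_var) and come back in the updated state. A rough NumPy sketch of that caching idea, with hypothetical names and without the per-timestep and padding handling that Lingvo's _StreamMoments performs:

    import numpy as np

    def toy_stream_moments(chunk, cached_sum, cached_count, cached_sq_sum):
        # chunk: [batch, time, ...]; the caches hold totals over earlier chunks.
        new_sum = cached_sum + chunk.sum(axis=1, keepdims=True)
        new_count = cached_count + chunk.shape[1]
        new_sq_sum = cached_sq_sum + (chunk ** 2).sum(axis=1, keepdims=True)
        mean = new_sum / new_count
        # E[x^2] - E[x]^2 recovers the variance without revisiting old frames.
        variance = new_sq_sum / new_count - mean ** 2
        return mean, variance, new_sum, new_count, new_sq_sum
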
Example #2
    def FProp(self, theta, inputs, paddings=None):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1] or [...]; its rank is
        either the same as inputs or tf.rank(inputs) - 1.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        inputs, paddings = self._CastToFPropDtype((inputs, paddings))

        if py_utils.testonly_skip_norm_layers():
            return inputs

        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)

        # shape [..., 1]
        paddings = self._MaybeExpandPaddings(inputs, paddings)

        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings)

            return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                                   norm_variance)
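
With paddings present, the moments fed to _ComputeBN should be computed only over real (non-padded) frames. A simplified sketch of such masked moments, assuming a float paddings tensor shaped [..., 1] where 1.0 marks padded positions; this is not ComputeAndUpdateMoments, which additionally maintains moving averages and can aggregate across replicas:

    import tensorflow as tf

    def masked_moments(inputs, paddings):
        # inputs: [..., dim], paddings: [..., 1]; mask is 1.0 on real frames.
        mask = 1.0 - paddings
        count = tf.maximum(tf.reduce_sum(mask), 1.0)
        reduce_dims = list(range(inputs.shape.rank - 1))
        mean = tf.reduce_sum(inputs * mask, axis=reduce_dims) / count
        variance = tf.reduce_sum(
            tf.math.squared_difference(inputs, mean) * mask,
            axis=reduce_dims) / count
        return mean, variance

The final step is then the usual gamma * (x - mean) / sqrt(variance + epsilon) + beta transform, which tf.nn.batch_normalization implements.
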
Example #3
    def FProp(self, theta, inputs):
        """Applies batch normalization.

    Uses the implementation in
    github.com/tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

    Args:
      theta: A nested map object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        if py_utils.testonly_skip_norm_layers():
            return inputs

        p = self.params
        inputs_dtype = inputs.dtype
        inputs = tf.cast(inputs, p.dtype)
        inputs = py_utils.with_dependencies([
            py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                        tf.shape(theta.beta))
        ], inputs)
        with tf.name_scope(p.name) as scope:
            if self.do_eval:
                outputs = tf.nn.batch_normalization(inputs, theta.moving_mean,
                                                    theta.moving_variance,
                                                    theta.beta, theta.gamma,
                                                    p.epsilon)
            else:
                mean, variance = self._Moments(inputs, p.bn_group_size)
                mean = py_utils.CheckNumerics(
                    mean, 'mean of {} failed numeric check'.format(scope))
                variance = py_utils.CheckNumerics(
                    variance,
                    'variance of {} failed numeric check'.format(scope))
                outputs = tf.nn.batch_normalization(inputs, mean, variance,
                                                    theta.beta, theta.gamma,
                                                    p.epsilon)
            outputs.set_shape(inputs.get_shape())
            return tf.cast(outputs, inputs_dtype)
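
The branch on self.do_eval is the standard batch-norm split: evaluation reuses the accumulated moving statistics, while training computes fresh moments from the current batch (here via the group-aware _Moments helper plus numeric checks). A minimal sketch of just that branching, using tf.nn.moments in place of _Moments and a placeholder epsilon:

    import tensorflow as tf

    def toy_bn_forward(x, beta, gamma, moving_mean, moving_var, do_eval,
                       epsilon=0.001):
        if do_eval:
            # Inference: normalize with the long-run moving statistics.
            mean, variance = moving_mean, moving_var
        else:
            # Training: normalize with the current batch's own moments.
            mean, variance = tf.nn.moments(x, axes=list(range(x.shape.rank - 1)))
        return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
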
Example #4
    def FProp(self, theta, inputs, paddings, class_emb):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [batch, ..., dim].
      paddings: The paddings tensor.  Shaped [batch, ..., 1], with the same rank
        as the input tensor.
      class_emb: The conditioning inputs.  Shaped [batch, emb_dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        if py_utils.testonly_skip_norm_layers():
            return inputs

        p = self.params
        batch = py_utils.GetShape(inputs)[0]
        class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])
        if not py_utils.use_tpu():
            class_emb = py_utils.with_dependencies([
                py_utils.assert_less_equal(
                    tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
                py_utils.assert_greater_equal(
                    tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
                py_utils.assert_equal(tf.ones([batch], tf.int32),
                                      tf.cast(tf.reduce_sum(class_emb, -1),
                                              tf.int32),
                                      name='one_hot_assert3'),
            ], class_emb)

        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings=paddings, class_emb=class_emb)
            return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                                   norm_variance)
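
The assertions guarded by `not py_utils.use_tpu()` check that class_emb is a valid one-hot matrix: every entry is 0 or 1 and every row sums to exactly 1. The same check outside the TF graph, as a small NumPy sketch (is_one_hot is a hypothetical helper, not part of the layer):

    import numpy as np

    def is_one_hot(class_emb):
        # Every entry must be 0 or 1, and every row must sum to exactly 1.
        e = np.asarray(class_emb)
        return bool(np.all((e == 0) | (e == 1)) and np.all(e.sum(axis=-1) == 1))

    assert is_one_hot([[0, 1, 0], [1, 0, 0]])
    assert not is_one_hot([[1, 1, 0]])
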
Example #5
    def FProp(self, theta, inputs, paddings=None):
        """Apply group normalization.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor with shape [batch_size, height, width, channel]
        if p.input_rank == 4, else [batch_size, height, channel].
      paddings: The paddings tensor with shape [batch_size, height]. Intended to
        be used for sequence processing where `height` is `time`.

    Returns:
      The output after applying group normalization, with the same shape as
      'inputs', if paddings is None; otherwise an (output, output_paddings)
      pair.
    """
        inputs, paddings = self._CastToFPropDtype((inputs, paddings))

        if py_utils.testonly_skip_norm_layers():
            if paddings is None:
                return inputs
            else:
                return inputs, paddings

        p = self.params
        inputs = py_utils.HasRank(inputs, p.input_rank)
        num_groups = self.num_groups

        input_shape = py_utils.GetShape(inputs)
        with tf.name_scope(p.name):
            x = tf.reshape(inputs,
                           input_shape[:-1] + [num_groups, self.group_size])
            expanded_rank = p.input_rank + 1
            all_dims = list(range(expanded_rank))
            if paddings is None or not p.cumulative:
                # Skips batch and num_groups.
                reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
            else:
                # Skips batch, seqlen and num_groups.
                reduce_over_dims = all_dims[2:-2] + all_dims[-1:]

            if paddings is None and not p.cumulative:
                # Fast path on tpu without reshape.
                counts, means_ss, variance_ss, _ = tf.nn.sufficient_statistics(
                    x, axes=reduce_over_dims, keepdims=True)
                group_mean, group_variance = tf.nn.normalize_moments(
                    counts, means_ss, variance_ss, None)
            else:
                expanded_paddings = tf.reshape(
                    paddings, input_shape[:2] + [1] * (expanded_rank - 2))
                group_mean, group_variance = ComputeMoments(
                    x,
                    expanded_paddings,
                    reduce_over_dims,
                    cumulative_axis=1,
                    enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu,
                    keepdims=True)

            outputs = self._Normalize(theta, x, group_mean, group_variance)

            if paddings is None:
                return outputs
            else:
                return outputs, paddings
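
Group normalization splits the channel dimension into num_groups groups and normalizes each group over the spatial dimensions together with its own channels, which is what the reshape plus reduce_over_dims bookkeeping above arranges. A compact sketch of the rank-4, no-paddings fast path; simple_group_norm and its epsilon default are illustrative, not the layer's API:

    import tensorflow as tf

    def simple_group_norm(x, num_groups, gamma, beta, epsilon=0.001):
        # x: [batch, height, width, channel]; channel must divide by num_groups.
        b, h, w, c = x.shape
        g = tf.reshape(x, [b, h, w, num_groups, c // num_groups])
        # Reduce over height, width and within-group channels; keep batch/group.
        mean, variance = tf.nn.moments(g, axes=[1, 2, 4], keepdims=True)
        g = (g - mean) * tf.math.rsqrt(variance + epsilon)
        return tf.reshape(g, [b, h, w, c]) * gamma + beta

    y = simple_group_norm(tf.random.normal([2, 4, 4, 8]), num_groups=2,
                          gamma=tf.ones([8]), beta=tf.zeros([8]))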