Example #1
    def apply_fun(params, x, **kwargs):
        bias_value, scale_value = params

        groups_num = num_groups
        if group_size is not None:
            channels = x.shape[-1]
            if channels % group_size != 0:
                raise ValueError(
                    'Number of channels ({}) is not a multiple of the '
                    'group size ({}).'.format(channels, group_size))
            groups_num = channels // group_size

        input_shape = x.shape
        group_shape = x.shape[:-1] + (groups_num, x.shape[-1] // groups_num)

        x = x.reshape(group_shape)

        reduction_axis = [d for d in range(1, x.ndim - 2)] + [x.ndim - 1]

        mean = np.mean(x, axis=reduction_axis, keepdims=True)
        mean_of_squares = np.mean(np.square(x),
                                  axis=reduction_axis,
                                  keepdims=True)
        var = mean_of_squares - np.square(mean)

        x = (x - mean) * lax.rsqrt(var + epsilon)

        x = x.reshape(input_shape)

        if scale and bias:
            return x * scale_value + bias_value
        if scale:
            return x * scale_value
        if bias:
            return x + bias_value
        return x
Example #2
  def __call__(self, x):
    """Applies layer normalization on the input.

    Args:
      x: the inputs

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    x = jnp.asarray(x, jnp.float32)
    features = x.shape[-1]
    mean = jnp.mean(x, axis=-1, keepdims=True)
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    var = mean2 - lax.square(mean)
    mul = lax.rsqrt(var + self.epsilon)
    if self.use_scale:
      mul = mul * jnp.asarray(
          self.param('scale', self.scale_init, (features,)),
          self.dtype)
    y = (x - mean) * mul
    if self.use_bias:
      y = y + jnp.asarray(
          self.param('bias', self.bias_init, (features,)),
          self.dtype)
    return jnp.asarray(y, self.dtype)
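This __call__ follows the flax.linen LayerNorm pattern: parameters are created via self.param and the module is driven through init/apply. A minimal usage sketch, assuming the standard flax.linen.LayerNorm module (which implements the same computation):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
layer = nn.LayerNorm(epsilon=1e-6)
params = layer.init(jax.random.PRNGKey(1), x)  # creates 'scale' and 'bias'
y = layer.apply(params, x)                     # same shape; per-row mean ~0, std ~1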
Example #3
 def forward(self,
             x,
             epsilon=1e-6,
             bias=True,
             scale=True):
   """Applies layer normalization on the input.
   It normalizes the activations of the layer for each given example in a
   batch independently, rather than across a batch like Batch Normalization.
   i.e. applies a transformation that maintains the mean activation within
   each example close to 0 and the activation standard deviation close to 1.
   Args:
     x: the inputs
     epsilon: A small float added to variance to avoid dividing by zero.
     bias:  If True, bias (beta) is added.
     scale: If True, multiply by scale (gamma). When the next layer is linear
       (also e.g. nn.relu), this can be disabled since the scaling will be done
       by the next layer.
   Returns:
     Normalized inputs (the same shape as inputs).
   """
   features = x.shape[-1]
   mean = jnp.mean(x, axis=-1, keepdims=True)
   mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
   var = mean2 - lax.square(mean)
   mul = lax.rsqrt(var + epsilon)
   if scale:
     mul = mul * self.scale
   y = (x - mean) * mul
   if bias:
     y = y + self.bias
   return y
Example #4
def segment_normalize(data: jnp.ndarray,
                      segment_ids: jnp.ndarray,
                      num_segments: Optional[int] = None,
                      indices_are_sorted: bool = False,
                      unique_indices: bool = False,
                      eps=1e-8):
    """Normalizes data within each segment.

  Args:
    data: values whose z-score normalized values will be calculated
      segment-wise.
    segment_ids: indices for segments.
    num_segments: total number of segments.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether ``segment_ids`` is known to be free of duplicates.
    eps: epsilon for numerical stability.

  Returns:
    array containing data normalized segment-wise.
  """

    means = segment_mean(data,
                         segment_ids,
                         num_segments,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)[segment_ids]
    variances = segment_variance(data,
                                 segment_ids,
                                 num_segments,
                                 indices_are_sorted=indices_are_sorted,
                                 unique_indices=unique_indices)[segment_ids]
    normalized = (data - means) * lax.rsqrt(
        jnp.maximum(variances, jnp.array(eps, dtype=variances.dtype)))
    return normalized
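Here segment_mean and segment_variance come from the surrounding library (jraph-style segment ops, assumed rather than shown). The same segment-wise z-score can be sketched directly on top of jax.ops.segment_sum; the jnp.maximum clamp mirrors the eps handling above:

import jax
import jax.numpy as jnp
from jax import lax

def segment_normalize_sketch(data, segment_ids, num_segments, eps=1e-8):
    # Per-segment counts, means and (biased) variances via segment_sum.
    counts = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments)
    means = jax.ops.segment_sum(data, segment_ids, num_segments) / counts
    centered = data - means[segment_ids]
    variances = jax.ops.segment_sum(centered**2, segment_ids, num_segments) / counts
    return centered * lax.rsqrt(jnp.maximum(variances[segment_ids], eps))

data = jnp.array([1.0, 2.0, 3.0, 10.0, 20.0])
segment_ids = jnp.array([0, 0, 0, 1, 1])
print(segment_normalize_sketch(data, segment_ids, num_segments=2))  # per-segment z-scores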
Example #5
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
            training: if true, compute statistics from the current batch and
                update the running averages; otherwise use the stored running
                averages.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
Example #6
def _givens_rotation(a, b):
    b_zero = abs(b) == 0
    a_lt_b = abs(a) < abs(b)
    t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
    r = lax.rsqrt(1 + abs(t)**2)
    cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
    sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
    return cs, sn
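A quick numerical check (a sketch that assumes the _givens_rotation above is in scope): the returned pair is unit-norm and annihilates the second component, i.e. sn*a + cs*b == 0, while cs*a - sn*b carries the magnitude hypot(a, b).

import jax.numpy as jnp

a, b = jnp.float32(3.0), jnp.float32(4.0)
cs, sn = _givens_rotation(a, b)     # (cs, sn) = (-0.6, 0.8) for this input
print(cs**2 + sn**2)                # ~1.0: a proper rotation
print(sn * a + cs * b)              # ~0.0: the second component is zeroed
print(jnp.abs(cs * a - sn * b))     # ~5.0 == hypot(3, 4)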
Example #7
        def quantized_layernorm(x):
            prec = hparams.quant_hparams.prec
            fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
            quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant,
                                                     bounds=None)

            def to_quantized(x):
                return quant_ops.to_quantized(x, dtype=dtype)

            # If epsilon is too small to represent in the quantized format, we set it
            # to the minimal representative non-zero value to avoid the possibility of
            # dividing by zero.
            fp_bounds = quantization.fp_cast.get_bounds(
                prec.exp_min, prec.exp_max, prec.sig_bits)
            epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
            quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

            # If the reciprocal of the quantized number of features is too small to
            # represent in the quantized format, we set it to the minimal
            # representative nonzero value so that the mean and variance are not
            # trivially 0.
            num_features_quantized = to_quantized(
                jnp.array(num_features, dtype=dtype))
            num_features_recip_quantized = to_quantized(
                jnp.reciprocal(num_features_quantized))
            num_features_recip_quantized = jax.lax.cond(
                jax.lax.eq(num_features_recip_quantized, 0.0),
                lambda _: quantized_epsilon,
                lambda _: num_features_recip_quantized,
                None)

            x_quantized = to_quantized(x)
            x_sum_quantized_reduction = quantization.quantized_sum(
                x_quantized,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sum = to_quantized(x_sum_quantized_reduction)
            mean = to_quantized(x_sum * num_features_recip_quantized)
            x_minus_mean = to_quantized(x - mean)
            x_sq = to_quantized(lax.square(x_minus_mean))
            x_sq_sum_quantized_reduction = quantization.quantized_sum(
                x_sq,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
            var = to_quantized(x_sq_sum * num_features_recip_quantized)
            # Prevent division by zero.
            var_plus_epsilon = to_quantized(var + quantized_epsilon)
            mul = to_quantized(lax.rsqrt(var_plus_epsilon))
            if self.use_scale:
                quantized_scale_param = to_quantized(scale_param)
                mul = to_quantized(mul * quantized_scale_param)
            y = to_quantized(x_minus_mean * mul)
            if self.use_bias:
                quantized_bias_param = to_quantized(bias_param)
                y = to_quantized(y + quantized_bias_param)
            return y.astype(self.dtype)
Example #8
def batch_norm(scope: Scope,
               x,
               use_running_average=False,
               axis=-1,
               momentum=0.99,
               epsilon=1e-5,
               dtype=jnp.float32,
               bias=True,
               scale=True,
               bias_init=initializers.zeros,
               scale_init=initializers.ones,
               axis_name=None,
               axis_index_groups=None,
               kind='batch_stats'):

    x = jnp.asarray(x, jnp.float32)
    axis = axis if isinstance(axis, tuple) else (axis, )
    axis = _absolute_dims(x.ndim, axis)
    redux = tuple(i for i in range(x.ndim) if i not in axis)

    def pmean(x):
        m = jnp.mean(x, redux, keepdims=True)
        if axis_name is not None:
            m = lax.pmean(m,
                          axis_name=axis_name,
                          axis_index_groups=axis_index_groups)
        return m

    mean = pmean(x)
    squeeze_shape = jnp.squeeze(mean).shape
    mean2 = pmean(jnp.square(x))
    var = mean2 - jnp.square(mean)

    is_init = not scope.has_variable(kind, 'mean')
    ra_mean = scope.variable(kind, 'mean', jnp.zeros, squeeze_shape)
    ra_var = scope.variable(kind, 'var', jnp.ones, squeeze_shape)

    if use_running_average:
        # if ra_mean is not None:
        #   raise ValueError('batch_stats should be provided if use_running_averages=True')
        mean = jnp.reshape(ra_mean.value, mean.shape)
        var = jnp.reshape(ra_var.value, var.shape)
    else:
        if not is_init:
            beta = 1. - momentum
            ra_mean.value += beta * (jnp.squeeze(mean) - ra_mean.value)
            ra_var.value += beta * (jnp.squeeze(var) - ra_var.value)
    y = x - mean
    mul = lax.rsqrt(var + epsilon)
    if scale:
        mul = mul * scope.param('scale', scale_init, squeeze_shape).reshape(
            mean.shape)
    y = y * mul
    if bias:
        y = y + scope.param('bias', bias_init, squeeze_shape).reshape(
            mean.shape)
    return jnp.asarray(y, dtype)
Example #9
File: functions.py Project: sts-sadr/jax
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
  """Normalizes an array by subtracting mean and dividing by sqrt(var)."""
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True)
  if variance is None:
    # this definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers
    variance = jnp.mean(x**2, axis, keepdims=True) - mean**2
  return (x - mean) * lax.rsqrt(variance + epsilon)
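The comment refers to the one-pass identity Var[x] = E[x^2] - (E[x])^2. A small sketch comparing it against jnp.var's two-pass definition (names here are purely illustrative):

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (4, 128))

# Two-pass variance, mean((x - mean(x))**2), as computed by jnp.var.
var_two_pass = jnp.var(x, axis=-1, keepdims=True)

# One-pass variance used above: E[x^2] - E[x]^2. One fewer pass over x, but
# it can lose precision when the mean is large relative to the spread.
mean = jnp.mean(x, axis=-1, keepdims=True)
var_one_pass = jnp.mean(x**2, axis=-1, keepdims=True) - mean**2

print(jnp.max(jnp.abs(var_two_pass - var_one_pass)))  # tiny for centered activations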
Example #10
    def __call__(self, x):
        """Applies group normalization to the input (arxiv.org/abs/1803.08494).

    Args:
      x: the input of shape N...C, where N is a batch dimension and C is a
        channels dimensions. `...` represents an arbitrary number of extra
        dimensions that are used to accumulate statistics over.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        x = jnp.asarray(x, jnp.float32)
        if ((self.num_groups is None and self.group_size is None) or
            (self.num_groups is not None and self.group_size is not None)):
            raise ValueError('Either `num_groups` or `group_size` should be '
                             'specified, but not both of them.')
        num_groups = self.num_groups

        channels = x.shape[-1]
        if self.group_size is not None:
            if channels % self.group_size != 0:
                raise ValueError(
                    'Number of channels ({}) is not a multiple of the '
                    'group size ({}).'.format(channels, self.group_size))
            num_groups = channels // self.group_size

        if num_groups <= 0 or channels % num_groups != 0:
            raise ValueError('Number of groups ({}) does not divide the number'
                             ' of channels ({}).'.format(num_groups, channels))

        input_shape = x.shape
        group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)
        x = x.reshape(group_shape)

        reduction_axis = [d for d in range(1, x.ndim - 2)] + [x.ndim - 1]
        mean = jnp.mean(x, axis=reduction_axis, keepdims=True)
        mean_of_squares = jnp.mean(jnp.square(x),
                                   axis=reduction_axis,
                                   keepdims=True)
        var = mean_of_squares - jnp.square(mean)
        x = (x - mean) * lax.rsqrt(var + self.epsilon)
        x = x.reshape(input_shape)

        feature_shape = tuple([1
                               for d in input_shape[:-1]] + [input_shape[-1]])
        if self.use_scale:
            x = x * self.param('scale', self.scale_init, feature_shape)
        if self.use_bias:
            x = x + self.param('bias', self.bias_init, feature_shape)

        return x.astype(self.dtype)
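A minimal usage sketch, assuming the standard flax.linen.GroupNorm module that this __call__ mirrors; exactly one of num_groups or group_size is given:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 64))  # N, H, W, C
layer = nn.GroupNorm(num_groups=32)                          # 2 channels per group
# equivalently: nn.GroupNorm(num_groups=None, group_size=2)
params = layer.init(jax.random.PRNGKey(1), x)
y = layer.apply(params, x)                                   # same shape as x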
Example #11
File: functions.py Project: zizai/jax
def normalize(x: Array,
              axis: Optional[Union[int, Tuple[int, ...]]] = -1,
              mean: Optional[Array] = None,
              variance: Optional[Array] = None,
              epsilon: Array = 1e-5) -> Array:
  """Normalizes an array by subtracting mean and dividing by sqrt(var)."""
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True)
  if variance is None:
    # this definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers
    variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
  return (x - mean) * lax.rsqrt(variance + epsilon)
Example #12
 def unquantized_layernorm(x):
     num_features_recip = jnp.reciprocal(num_features)
     x_sum = jnp.sum(x, axis=-1, keepdims=True)
     mean = x_sum * num_features_recip
     x_minus_mean = x - mean
     x_sq = lax.square(x_minus_mean)
     x_sq_sum = jnp.sum(x_sq, axis=-1, keepdims=True)
     var = x_sq_sum * num_features_recip
     var_plus_epsilon = var + self.epsilon
     mul = lax.rsqrt(var_plus_epsilon)
     if self.use_scale:
         mul = mul * scale_param
     y = x_minus_mean * mul
     if self.use_bias:
         y = y + bias_param
     return y.astype(self.dtype)
Example #13
def segment_normalize(data, segment_ids, num_segments, eps=1e-5):
    """Normalizes data within each segment.

  Args:
    data: values whose z-score normalized values will be calculated
      segment-wise.
    segment_ids: indices for segments.
    num_segments: total number of segments.
    eps: epsilon for numerical stability.

  Returns:
    array containing data normalized segment-wise.
  """
    means = segment_mean(data, segment_ids, num_segments)[segment_ids]
    variances = segment_variance(data, segment_ids, num_segments)[segment_ids]
    normalized = (data - means) * lax.rsqrt(variances + eps)
    return jnp.nan_to_num(normalized)
Example #14
    def apply(self,
              x,
              num_groups=32,
              group_size=None,
              epsilon=1e-6,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones):
        """Applies group normalization to the input (arxiv.org/abs/1803.08494).

    This op is similar to batch normalization, but statistics are shared across
    equally-sized groups of channels and not shared across batch dimension.
    Thus, group normalization does not depend on the batch composition and does
    not require maintaining internal state for storing statistics.

    The user should either specify the total number of channel groups or the
    number of channels per group.

    Args:
      x: the input of shape N...C, where N is a batch dimension and C is a
        channels dimensions. `...` represents an arbitrary number of extra
        dimensions that are used to accumulate statistics over.
      num_groups: the total number of channel groups. The default value of 32 is
        proposed by the original group normalization paper.
      group_size: the number of channels in a group.
      epsilon: A small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  If True, bias (beta) is added.
      scale: If True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: Initializer for bias, by default, zero.
      scale_init: Initializer for scale, by default, one.

    Returns:
      Normalized inputs (the same shape as inputs).

    """
        x = jnp.asarray(x, jnp.float32)
        if ((num_groups is None and group_size is None)
                or (num_groups is not None and group_size is not None)):
            raise ValueError('Either `num_groups` or `group_size` should be '
                             'specified, but not both of them.')

        if group_size is not None:
            channels = x.shape[-1]
            if channels % group_size != 0:
                raise ValueError(
                    'Number of channels ({}) is not a multiple of the '
                    'group size ({}).'.format(channels, group_size))
            num_groups = channels // group_size

        input_shape = x.shape
        group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)

        x = x.reshape(group_shape)

        reduction_axis = [d for d in range(1, x.ndim - 2)] + [x.ndim - 1]

        mean = jnp.mean(x, axis=reduction_axis, keepdims=True)
        mean_of_squares = jnp.mean(jnp.square(x),
                                   axis=reduction_axis,
                                   keepdims=True)
        var = mean_of_squares - jnp.square(mean)

        x = (x - mean) * lax.rsqrt(var + epsilon)

        x = x.reshape(input_shape)

        feature_shape = tuple([1
                               for d in input_shape[:-1]] + [input_shape[-1]])
        if scale:
            x = x * self.param('scale', feature_shape, scale_init)
        if bias:
            x = x + self.param('bias', feature_shape, bias_init)

        return x.astype(dtype)
Example #15
    def apply(self,
              x,
              batch_stats=None,
              use_running_average=False,
              axis=-1,
              momentum=0.99,
              epsilon=1e-5,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones,
              axis_name=None):
        """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats
        will be used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of
        the batch statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma).
        When the next layer is linear (also e.g. nn.relu), this can be disabled
        since the scaling will be done by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        x = jnp.asarray(x, jnp.float32)
        axis = axis if isinstance(axis, tuple) else (axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
        if self.is_stateful() or batch_stats:
            ra_mean = self.state('mean',
                                 reduced_feature_shape,
                                 initializers.zeros,
                                 collection=batch_stats)
            ra_var = self.state('var',
                                reduced_feature_shape,
                                initializers.ones,
                                collection=batch_stats)
        else:
            ra_mean = None
            ra_var = None

        if use_running_average:
            if ra_mean is None:
                raise ValueError('batch_stats should be provided if '
                                 'use_running_averages is True')
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            if axis_name is not None and not self.is_initializing():
                mean = lax.pmean(mean, axis_name=axis_name)

            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if axis_name is not None and not self.is_initializing():
                mean2 = lax.pmean(mean2, axis_name=axis_name)
            var = mean2 - lax.square(mean)

            if ra_mean and not self.is_initializing():
                ra_mean.value = momentum * ra_mean.value + (1 -
                                                            momentum) * mean
                ra_var.value = momentum * ra_var.value + (1 - momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + epsilon)
        if scale:
            mul = mul * self.param('scale', reduced_feature_shape,
                                   scale_init).reshape(feature_shape)
        y = y * mul
        if bias:
            y = y + self.param('bias', reduced_feature_shape,
                               bias_init).reshape(feature_shape)
        return jnp.asarray(y, dtype)
Example #16
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

    NOTE:
    During initialization (when parameters are mutable) the running average
    of the batch statistics will not be updated. Therefore, the inputs
    fed during initialization don't need to match that of the actual input
    distribution and the reduction axis (set with `axis_name`) does not have
    to exist.

    Args:
      x: the input to be normalized.
      use_running_average: if true, the statistics stored in batch_stats
        will be used instead of computing the batch statistics on the input.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        use_running_average = merge_param('use_running_average',
                                          self.use_running_average,
                                          use_running_average)
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # see NOTE above on initialization behavior
        initializing = self.is_mutable_collection('params')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
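A minimal usage sketch, assuming the standard flax.linen.BatchNorm API that this __call__ corresponds to: the batch_stats collection must be marked mutable during training so the running averages can be updated, and use_running_average=True switches to those averages at evaluation time.

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (32, 64))
bn = nn.BatchNorm(momentum=0.99, epsilon=1e-5)
variables = bn.init(jax.random.PRNGKey(1), x, use_running_average=False)

# Training: compute batch statistics and collect the updated running averages.
y, updates = bn.apply(variables, x, use_running_average=False,
                      mutable=['batch_stats'])
variables = {'params': variables['params'], **updates}

# Evaluation: normalize with the stored running averages instead.
y_eval = bn.apply(variables, x, use_running_average=True)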
Example #17
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

    NOTE:
    During initialization (when parameters are mutable) the running average
    of the batch statistics will not be updated. Therefore, the inputs
    fed during initialization don't need to match that of the actual input
    distribution and the reduction axis (set with `axis_name`) does not have
    to exist.

    Args:
      x: the input to be normalized.
      use_running_average: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        use_running_average = nn.module.merge_param('use_running_average',
                                                    self.use_running_average,
                                                    use_running_average)

        virtual_batch_size = self.virtual_batch_size
        batch_axis = _get_batch_axis(self.data_format, x, virtual_batch_size,
                                     use_running_average,
                                     self.axis_index_groups)
        if virtual_batch_size is None:
            virtual_batch_size = self.batch_size

        if use_running_average:
            # Virtual batch norm is not used during evaluation, and we cannot
            # guarantee the train and eval batch sizes are the same, so we use a
            # single virtual batch of size batch_size, and take the first element in
            # the running average array, assuming they have been properly synced
            # across their first dim.
            virtual_batch_size = x.shape[batch_axis]

        x = jnp.asarray(x, jnp.float32)
        # Note that this should only ever default to the first case if we are
        # passing in a batch `x` with fewer examples than `virtual_batch_size`, which
        # should only happen if we are initializing with dummy variables (typically
        # of batch size 2).
        num_sub_batches = max(1, x.shape[batch_axis] // virtual_batch_size)
        input_shape = x.shape
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        # Add an additional axis because we are going to reshape `x` to have a
        # leading dim of size `virtual_batch_size`.
        reduction_axis = tuple(i + 1 for i in range(x.ndim) if i not in axis)
        sub_batched_shape = (
            num_sub_batches,
            *x.shape[:batch_axis],
            # Necessary for when passing in a batch `x` with fewer examples than
            # `virtual_batch_size`, which should only happen if we are initializing
            # with dummy variables (typically of batch size 2).
            min(x.shape[batch_axis], virtual_batch_size),
            *x.shape[batch_axis + 1:])
        x = jnp.reshape(x, sub_batched_shape)
        ra_mean = self.variable('batch_stats', 'batch_norm_running_mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                feature_shape)
        ra_var = self.variable('batch_stats', 'batch_norm_running_var',
                               lambda s: jnp.ones(s, jnp.float32),
                               feature_shape)
        # If using gradient accumulation, use these to accumulate the activations
        # for the current batch before folding them into the running average.
        mean_accumulator = self.variable('batch_stats',
                                         'batch_norm_mean_accumulator',
                                         lambda s: jnp.zeros(s, jnp.float32),
                                         feature_shape)
        mean2_accumulator = self.variable('batch_stats',
                                          'batch_norm_mean2_accumulator',
                                          lambda s: jnp.zeros(s, jnp.float32),
                                          feature_shape)

        # A counter that is used to determine which accumulation pass we are
        # currently in. This will increment from 0 until we have accumulated
        # gradients calculated on `self.total_batch_size` examples. This should only
        # ever be saved on disk as 0 because we only checkpoint after accumulating
        # enough examples to make an update.
        grad_accum_counter = self.variable('batch_stats', 'grad_accum_counter',
                                           lambda s: jnp.zeros(s, jnp.int32),
                                           [])

        # See NOTE above on initialization behavior.
        initializing = self.is_mutable_collection('params')

        if self.total_batch_size is None:
            passes_per_total_batch = 1
        else:
            passes_per_total_batch = self.total_batch_size // self.batch_size

        if use_running_average:
            # Note that we assume that the values across the first axis have been
            # properly synchronized.
            mean = jnp.expand_dims(ra_mean.value, 0)
            var = jnp.expand_dims(ra_var.value, 0)
        else:
            # Shape (num_sub_batches, x.shape[axis]).
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                mean_accumulator.value += jnp.mean(mean, axis=0)
                mean2_accumulator.value += jnp.mean(mean2, axis=0)
                grad_accum_counter_inc = grad_accum_counter.value + 1
                # This will be 0 for all gradient accumulation passes except for the
                # last one when we have seen enough examples to make an update to the
                # running averages.
                should_update_ra = grad_accum_counter_inc // passes_per_total_batch
                ra_mean_update = (should_update_ra * mean_accumulator.value /
                                  grad_accum_counter_inc)
                ra_mean.value = ((1 - should_update_ra *
                                  (1 - self.momentum)) * ra_mean.value +
                                 (1 - self.momentum) * ra_mean_update)
                ra_var_update = should_update_ra * (
                    mean2_accumulator.value / grad_accum_counter_inc -
                    lax.square(
                        mean_accumulator.value / grad_accum_counter_inc))
                ra_var.value = ((1 - should_update_ra *
                                 (1 - self.momentum)) * ra_var.value +
                                (1 - self.momentum) * ra_var_update)

                grad_accum_counter.value = (grad_accum_counter_inc %
                                            passes_per_total_batch)
                # Reset the activation accumulators every `passes_per_total_batch` steps
                # (np.sign == 0 if grad_accum_counter == 0).
                mean_accumulator.value *= jnp.sign(grad_accum_counter.value)
                mean2_accumulator.value *= jnp.sign(grad_accum_counter.value)

        y = x - mean.reshape((num_sub_batches, *feature_shape))
        mul = lax.rsqrt(
            var.reshape((num_sub_batches, *feature_shape)) + self.epsilon)
        if self.use_scale:
            mul = mul * self.param('scale', self.scale_init,
                                   reduced_feature_shape).reshape(
                                       (1, *feature_shape))
        y = y * mul
        if self.use_bias:
            y = y + self.param('bias', self.bias_init,
                               reduced_feature_shape).reshape(
                                   (1, *feature_shape))
        y = jnp.reshape(y, input_shape)
        return jnp.asarray(y, self.dtype)
Example #18
def pairwise_cosine_similarity(embeddings1, embeddings2, eps=1e-8):
    sq_norm1 = jnp.sum(embeddings1**2, axis=-1, keepdims=True)
    sq_norm2 = jnp.expand_dims(jnp.sum(embeddings2**2, axis=-1), axis=0)
    dot_product = jnp.matmul(embeddings1, embeddings2.transpose())
    inverse_norm = rsqrt(jnp.maximum(sq_norm1 * sq_norm2, eps**2))
    return dot_product * inverse_norm
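A short sanity check (a sketch that assumes the function above is in scope and that the bare rsqrt is jax.lax.rsqrt): self-similarities are ~1 and all entries stay within [-1, 1]; the eps**2 floor guards against zero-norm rows.

import jax.numpy as jnp
from jax import random
from jax.lax import rsqrt  # the bare `rsqrt` used above

e = random.normal(random.PRNGKey(0), (5, 16))
sims = pairwise_cosine_similarity(e, e)
print(jnp.diag(sims))          # ~1.0 on the diagonal
print(sims.min(), sims.max())  # values stay (approximately) within [-1, 1]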
Example #19
    def apply(self,
              x,
              batch_stats=None,
              use_running_average=False,
              axis=-1,
              momentum=0.99,
              epsilon=1e-5,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones,
              axis_name=None,
              axis_index_groups=None,
              virtual_batch_size=None,
              data_format=None):
        """Normalizes the input using batch statistics.

    Forked from the original flax nn.BatchNorm layer, this allows users to have
    multiple EMAs per device, one for each virtual batch size. For example, if
    the per-device batch size is 128 and the user specifies
    `virtual_batch_size=32`, 4 EMAs will be created on each device, each updated
    with 1/4 of the per-device batch on each forward pass.

    WARNING: the multiple per-device EMAs this creates need to be manually
    synchronized within each device before being used for evaluation, or when
    synchronizing batch norm statistic across devices.

    Args:
      x: the input to be normalized.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats
        will be used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of
        the batch statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma).
        When the next layer is linear (also e.g. nn.relu), this can be disabled
        since the scaling will be done by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
        examples on the first two and last two devices. See `jax.lax.psum` for
        more details.
      virtual_batch_size: the size of the virtual batches to construct on
        each device, which will be used to normalize sub-batches of each
        per-device batch. Will create a running average
        with a leading dim of size `x.shape[batch_axis] // virtual_batch_size`,
        one for each sub-batch. Note that the first dim of each state must be
        synchronized whenever synchronizing batch norm running averages. Must
        evenly divide the per-device batch size (as determined by `x`), and
        cannot be combined with `axis_index_groups`. Passing the default value
        of None will replicate the existing nn.BatchNorm behavior without
        virtual batches.
      data_format: only used when `virtual_batch_size` is set, to determine the
        batch axis.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        batch_axis = _get_batch_axis(data_format, x, virtual_batch_size,
                                     use_running_average, axis_index_groups)
        if virtual_batch_size is None:
            virtual_batch_size = x.shape[batch_axis]

        if use_running_average:
            # Virtual batch norm is not used during evaluation, and we cannot
            # guarantee the train and eval batch sizes are the same, so we use a
            # single virtual batch of size batch_size, and take the first element in
            # the running average array, assuming they have been properly synced
            # across their first dim.
            virtual_batch_size = x.shape[batch_axis]

        x = jnp.asarray(x, jnp.float32)
        num_sub_batches = x.shape[batch_axis] // virtual_batch_size
        input_shape = x.shape
        axis = axis if isinstance(axis, tuple) else (axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        # Add an additional axis because we are going to reshape `x` to have a
        # leading dim of size `virtual_batch_size`.
        reduction_axis = tuple(i + 1 for i in range(x.ndim) if i not in axis)
        sub_batched_shape = (num_sub_batches, *x.shape[:batch_axis],
                             virtual_batch_size, *x.shape[batch_axis + 1:])
        x = jnp.reshape(x, sub_batched_shape)
        if self.is_stateful() or batch_stats:
            ra_means = self.state('batch_norm_running_mean',
                                  (num_sub_batches, *reduced_feature_shape),
                                  initializers.zeros,
                                  collection=batch_stats)
            ra_vars = self.state('batch_norm_running_var',
                                 (num_sub_batches, *reduced_feature_shape),
                                 initializers.ones,
                                 collection=batch_stats)
        else:
            ra_means = None
            ra_vars = None

        if use_running_average:
            if ra_means is None:
                raise ValueError(
                    'when use_running_averages is True '
                    'either use a stateful context or provide batch_stats')
            # Note that we assume that the values across the first axis have been
            # properly synchronized.
            mean = jnp.expand_dims(ra_means.value[0], 0)
            var = jnp.expand_dims(ra_vars.value[0], 0)
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if axis_name is not None and not self.is_initializing():
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=axis_name,
                              axis_index_groups=axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if ra_means and not self.is_initializing():
                ra_means.value = momentum * ra_means.value + (1 -
                                                              momentum) * mean
                ra_vars.value = momentum * ra_vars.value + (1 - momentum) * var

        y = x - mean.reshape((num_sub_batches, *feature_shape))
        mul = lax.rsqrt(
            var.reshape((num_sub_batches, *feature_shape)) + epsilon)
        if scale:
            mul = mul * self.param('scale', reduced_feature_shape,
                                   scale_init).reshape((1, *feature_shape))
        y = y * mul
        if bias:
            y = y + self.param('bias', reduced_feature_shape,
                               bias_init).reshape((1, *feature_shape))
        y = jnp.reshape(y, input_shape)
        return jnp.asarray(y, dtype)
Example #20
    def __call__(self, inputs, use_running_stats=None):
        """Normalizes the input using batch (optional) means and variances.

    Stats are computed over the batch and spherical dimensions: (0, 1, 2).

    Args:
      inputs: An array of dimensions (batch_size, resolution, resolution,
        n_spins_in, n_channels_in).
      use_running_stats: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        use_running_stats = nn.module.merge_param("use_running_stats",
                                                  self.use_running_stats,
                                                  use_running_stats)

        # Normalization is independent per spin per channel.
        num_spins, num_channels = inputs.shape[-2:]
        feature_shape = (1, 1, 1, num_spins, num_channels)
        reduced_feature_shape = (num_spins, num_channels)

        initializing = not self.has_variable("batch_stats", "variance")

        running_variance = self.variable("batch_stats", "variance",
                                         lambda s: jnp.ones(s, jnp.float32),
                                         reduced_feature_shape)

        if self.centered:
            running_mean = self.variable("batch_stats", "mean",
                                         lambda s: jnp.zeros(s, jnp.complex64),
                                         reduced_feature_shape)

        if use_running_stats:
            variance = running_variance.value
            if self.centered:
                mean = running_mean.value
        else:
            # Compute the spherical mean over the spherical grid dimensions, then a
            # conventional mean over the batch.
            if self.centered:
                mean = sphere_utils.spin_spherical_mean(inputs)
                mean = jnp.mean(mean, axis=0)
            # Complex variance is E[x x*] - E[x]E[x*].
            # For spin != 0, E[x] should be zero, although due to discretization this
            # is not always true. We only use E[x x*] here.
            # E[x x*]:
            mean_abs_squared = sphere_utils.spin_spherical_mean(inputs *
                                                                inputs.conj())
            mean_abs_squared = jnp.mean(mean_abs_squared, axis=0)
            # Aggregate means over devices.
            if self.axis_name is not None and not initializing:
                if self.centered:
                    mean = lax.pmean(mean, axis_name=self.axis_name)
                mean_abs_squared = lax.pmean(mean_abs_squared,
                                             axis_name=self.axis_name)

            # Imaginary part is negligible.
            variance = mean_abs_squared.real

            if not initializing:
                running_variance.value = (
                    self.momentum * running_variance.value +
                    (1 - self.momentum) * variance)
                if self.centered:
                    running_mean.value = (self.momentum * running_mean.value +
                                          (1 - self.momentum) * mean)

        if self.centered:
            outputs = inputs - mean.reshape(feature_shape)
        else:
            outputs = inputs

        factor = lax.rsqrt(variance.reshape(feature_shape) + self.epsilon)
        if self.use_scale:
            scale = self.param("scale", self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            factor = factor * scale

        outputs = outputs * factor

        if self.use_bias:
            bias = self.param("bias", self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            outputs = outputs + bias

        return outputs