Пример #1
0
    def __call__(self, x):
        """Normalizes the input using batch statistics.

        The input is cast to float32 for the computation. Depending on
        `self.use_running_average`, either the stored running statistics
        (`batch_stats/mean` and `batch_stats/var`) or statistics computed
        from `x` are used. In the latter case the running statistics are
        updated with an exponential moving average controlled by
        `self.momentum`, except during initialization.

        Args:
          x: the input to be normalized.

        Returns:
          Normalized inputs (the same shape as inputs), cast to `self.dtype`.
        """
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        # Presumably resolves negative axis indices against x.ndim -- the
        # helper is defined outside this block.
        axis = _absolute_dims(x.ndim, axis)
        # Shape that broadcasts per-feature values against `x`: feature axes
        # keep their size, every other axis becomes 1.
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        # Shape of the stored statistics and parameters: feature axes only.
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        # Every non-feature axis is averaged over.
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        # This must be evaluated before the variables are declared below.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if self.use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                # Average E[x] and E[x^2] across the named axis with a single
                # collective by concatenating them first.
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            # E[x^2] - E[x]^2 is not guaranteed to be non-negative because of
            # floating point round-off; a negative value larger in magnitude
            # than epsilon would make rsqrt below return NaN, so clamp at 0.
            var = jnp.maximum(0., mean2 - lax.square(mean))

            if not initializing:
                # Exponential moving average of the batch statistics.
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
Пример #2
0
def pmean(xs, axis_name):
  """Deprecated alias for `jax.lax.pmean`.

  Emits a `DeprecationWarning` and forwards to `lax.pmean`.

  Args:
    xs: the value(s) to average, passed through to `lax.pmean`.
    axis_name: the mapped axis name, passed through to `lax.pmean`.

  Returns:
    The result of `lax.pmean(xs, axis_name)`.
  """
  # stacklevel=2 attributes the warning to the caller of this wrapper, which
  # is the code that actually needs to be migrated.
  warnings.warn('use jax.lax.pmean instead',
                DeprecationWarning,
                stacklevel=2)
  return lax.pmean(xs, axis_name)
Пример #3
0
# pmean only works inside a mapped context because it needs an axis name, so
# it is wrapped in pmap here. The wrapper is built once at module level:
# creating a new lambda and a new pmap on every call would prevent reuse of
# the previously traced/compiled computation.
_cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')


def sync_batch_stats(state):
    """Sync the batch statistics across replicas.

    Args:
      state: an object exposing a `model_state` attribute and a
        `replace(model_state=...)` method.

    Returns:
      The result of `state.replace(...)` with `model_state` averaged over
      the pmapped axis 'x'.
    """
    return state.replace(model_state=_cross_replica_mean(state.model_state))
Пример #4
0
def eval_step(model, batch):
  """Runs one evaluation step and returns the loss averaged over 'batch'.

  Args:
    model: callable invoked as `model(images, dropout_p=0)`.
    batch: mapping with an 'image' entry, used both as model input and as
      the target of the loss.

  Returns:
    A dict with a single 'loss' entry: the negative log-likelihood loss,
    averaged with `lax.pmean` over the mapped axis named 'batch'.
  """
  inputs = batch['image']
  # Dropout is turned off for evaluation.
  predictions = model(inputs, dropout_p=0)
  local_loss = neg_log_likelihood_loss(predictions, inputs)
  return dict(loss=lax.pmean(local_loss, 'batch'))
Пример #5
0
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

        NOTE:
        During initialization (when parameters are mutable) the running
        average of the batch statistics will not be updated. Therefore, the
        inputs fed during initialization don't need to match that of the
        actual input distribution and the reduction axis (set with
        `axis_name`) does not have to exist.

        Args:
          x: the input to be normalized.
          use_running_average: if true, the statistics stored in batch_stats
            will be used instead of computing the batch statistics on the
            input. Combined with the module attribute of the same name via
            `merge_param`.

        Returns:
          Normalized inputs (the same shape as inputs), cast to `self.dtype`.
        """
        use_running_average = merge_param('use_running_average',
                                          self.use_running_average,
                                          use_running_average)
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        # Presumably resolves negative axis indices against x.ndim -- the
        # helper is defined outside this block.
        axis = _absolute_dims(x.ndim, axis)
        # Shape that broadcasts per-feature values against `x`: feature axes
        # keep their size, every other axis becomes 1.
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        # Shape of the stored statistics and parameters: feature axes only.
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        # Every non-feature axis is averaged over.
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # Initialization is detected by the 'params' collection being
        # mutable; see the NOTE in the docstring for what that implies.
        initializing = self.is_mutable_collection('params')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                # Average E[x] and E[x^2] across the named axis with a single
                # collective by concatenating them first.
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            # E[x^2] - E[x]^2 is not guaranteed to be non-negative because of
            # floating point round-off; a negative value larger in magnitude
            # than epsilon would make rsqrt below return NaN, so clamp at 0.
            var = jnp.maximum(0., mean2 - lax.square(mean))

            if not initializing:
                # Exponential moving average of the batch statistics.
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
Пример #6
0
def cross_device_avg(pytree):
  """Averages every leaf of `pytree` over the mapped axis named 'batch'.

  Must be called inside a context that binds the 'batch' axis name, since
  `lax.pmean` needs it.

  Args:
    pytree: a pytree whose leaves are passed to `lax.pmean`.

  Returns:
    A pytree with the same structure whose leaves are the 'batch'-axis means.
  """
  # jax.tree_util.tree_map is the stable spelling; the top-level
  # `jax.tree_map` alias is deprecated and absent from newer JAX releases.
  return jax.tree_util.tree_map(lambda x: lax.pmean(x, 'batch'), pytree)
Пример #7
0
def sync_batchnorm_stats(state):
  """Averages `state` over the mapped axis named 'batch'.

  Args:
    state: the value (or pytree of values) handed to `lax.pmean`.

  Returns:
    `state` averaged over the 'batch' axis.
  """
  # TODO(jekbradbury): use different formula for running variances?
  averaged = lax.pmean(state, axis_name='batch')
  return averaged