def apply_fun(params, x, **kwargs):
  """Applies group normalization to `x` and optionally scales/shifts it.

  `num_groups`, `group_size`, `epsilon`, `scale` and `bias` are read from the
  enclosing scope.

  Args:
    params: a `(bias_value, scale_value)` pair.
    x: the input; its last axis is split into channel groups.
    **kwargs: accepted for interface compatibility and ignored.

  Returns:
    The normalized input with the same shape as `x`, multiplied by
    `scale_value` when `scale` is set and shifted by `bias_value` when `bias`
    is set.

  Raises:
    ValueError: if `group_size` is given and does not divide the number of
      channels.
  """
  bias_value, scale_value = params
  channels = x.shape[-1]
  groups_num = num_groups
  if group_size is not None:
    if channels % group_size != 0:
      raise ValueError(
          'Number of channels ({}) is not multiple of the '
          'group size ({}).'.format(channels, group_size))
    # The group count follows from the group size (the previous code divided
    # by the group count itself, which is wrong and fails when it is None).
    groups_num = channels // group_size
  input_shape = x.shape
  group_shape = x.shape[:-1] + (groups_num, channels // groups_num)
  x = x.reshape(group_shape)
  # Reduce over every axis between the leading one and the group axis, plus
  # the trailing within-group channel axis. A tuple is accepted by both NumPy
  # and jax.numpy reductions.
  reduction_axis = tuple(range(1, x.ndim - 2)) + (x.ndim - 1,)
  mean = np.mean(x, axis=reduction_axis, keepdims=True)
  mean_of_squares = np.mean(np.square(x), axis=reduction_axis, keepdims=True)
  var = mean_of_squares - np.square(mean)
  x = (x - mean) * lax.rsqrt(var + epsilon)
  x = x.reshape(input_shape)
  if scale:
    x = x * scale_value
  if bias:
    x = x + bias_value
  # Always return the normalized array, even when neither scale nor bias is
  # requested (previously this fell through and returned None).
  return x
def __call__(self, x):
  """Layer-normalizes the input over its last axis.

  The input is cast to float32, centered and rescaled with the statistics of
  the last axis, then optionally multiplied by a learned `scale` and shifted
  by a learned `bias` (each of shape `(features,)`).

  Args:
    x: the inputs

  Returns:
    Normalized inputs (the same shape as inputs), cast to `self.dtype`.
  """
  inputs = jnp.asarray(x, jnp.float32)
  num_features = inputs.shape[-1]
  first_moment = jnp.mean(inputs, axis=-1, keepdims=True)
  second_moment = jnp.mean(lax.square(inputs), axis=-1, keepdims=True)
  # Variance as E[x^2] - E[x]^2.
  variance = second_moment - lax.square(first_moment)
  multiplier = lax.rsqrt(variance + self.epsilon)
  if self.use_scale:
    gamma = self.param('scale', self.scale_init, (num_features,))
    multiplier = multiplier * jnp.asarray(gamma, self.dtype)
  normalized = (inputs - first_moment) * multiplier
  if self.use_bias:
    beta = self.param('bias', self.bias_init, (num_features,))
    normalized = normalized + jnp.asarray(beta, self.dtype)
  return jnp.asarray(normalized, self.dtype)
def forward(self, x, epsilon=1e-6, bias=True, scale=True):
  """Applies layer normalization over the last axis of the input.

  Each example is normalized independently using the mean and variance of its
  own last axis, rather than statistics taken across a batch.

  Args:
    x: the inputs
    epsilon: A small float added to variance to avoid dividing by zero.
    bias: If True, `self.bias` is added to the normalized values.
    scale: If True, the normalized values are multiplied by `self.scale`.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  first_moment = jnp.mean(x, axis=-1, keepdims=True)
  second_moment = jnp.mean(lax.square(x), axis=-1, keepdims=True)
  # Variance as E[x^2] - E[x]^2.
  inv_std = lax.rsqrt(second_moment - lax.square(first_moment) + epsilon)
  multiplier = inv_std * self.scale if scale else inv_std
  normalized = (x - first_moment) * multiplier
  if not bias:
    return normalized
  return normalized + self.bias
def segment_normalize(data: jnp.ndarray,
                      segment_ids: jnp.ndarray,
                      num_segments: Optional[int] = None,
                      indices_are_sorted: bool = False,
                      unique_indices: bool = False,
                      eps=1e-8):
  """Z-score normalizes `data` separately within each segment.

  Args:
    data: values to normalize segment-wise.
    segment_ids: segment index of each entry of `data`.
    num_segments: total number of segments.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether ``segment_ids`` is known to be free of duplicates.
    eps: lower bound applied to the variance before the inverse square root.

  Returns:
    array containing data normalized segment-wise.
  """
  segment_options = dict(indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)
  # Gather the per-segment statistics back to one value per data entry.
  per_entry_mean = segment_mean(
      data, segment_ids, num_segments, **segment_options)[segment_ids]
  per_entry_var = segment_variance(
      data, segment_ids, num_segments, **segment_options)[segment_ids]
  variance_floor = jnp.array(eps, dtype=per_entry_var.dtype)
  inv_std = lax.rsqrt(jnp.maximum(per_entry_var, variance_floor))
  return (data - per_entry_mean) * inv_std
def __call__(self, x, training: bool):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    training: if False, the running mean/variance stored in the
      `batch_stats` collection are used; if True, statistics are computed
      from `x` and (outside of initialization) folded into the running
      averages.

  Returns:
    Normalized inputs (the same shape as inputs), cast to `self.dtype`.
  """
  x = jnp.asarray(x, jnp.float32)
  axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
  axis = _absolute_dims(x.ndim, axis)
  # Broadcastable shape: feature axes keep their size, all others are 1.
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  # Shape of the stored statistics and parameters: feature axes only.
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

  # we detect if we're in initialization via empty variable tree.
  # This check must run before the variables are created below.
  initializing = not self.has_variable('batch_stats', 'mean')

  ra_mean = self.variable('batch_stats', 'mean',
                          lambda s: jnp.zeros(s, jnp.float32),
                          reduced_feature_shape)
  ra_var = self.variable('batch_stats', 'var',
                         lambda s: jnp.ones(s, jnp.float32),
                         reduced_feature_shape)

  if not training:
    mean, var = ra_mean.value, ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if self.axis_name is not None and not initializing:
      # Average both moments across devices with a single collective.
      concatenated_mean = jnp.concatenate([mean, mean2])
      mean, mean2 = jnp.split(
          lax.pmean(concatenated_mean,
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
    # Variance as E[x^2] - E[x]^2.
    var = mean2 - lax.square(mean)

    if not initializing:
      # Exponential moving average of the batch statistics.
      ra_mean.value = self.momentum * ra_mean.value + (
          1 - self.momentum) * mean
      ra_var.value = self.momentum * ra_var.value + (
          1 - self.momentum) * var

  y = x - mean.reshape(feature_shape)
  mul = lax.rsqrt(var + self.epsilon)
  if self.use_scale:
    scale = self.param('scale', self.scale_init,
                       reduced_feature_shape).reshape(feature_shape)
    mul = mul * scale
  y = y * mul
  if self.use_bias:
    bias = self.param('bias', self.bias_init,
                      reduced_feature_shape).reshape(feature_shape)
    y = y + bias
  return jnp.asarray(y, self.dtype)
def _givens_rotation(a, b):
  """Returns the `(cs, sn)` coefficient pair computed from `a` and `b`.

  The ratio is always formed with the larger-magnitude value in the
  denominator. When `b` is exactly zero the result is `(1, 0)`.

  Args:
    a: first value.
    b: second value.

  Returns:
    A tuple `(cs, sn)`.
  """
  magnitude_a = abs(a)
  magnitude_b = abs(b)
  b_is_zero = magnitude_b == 0
  a_is_smaller = magnitude_a < magnitude_b
  numerator = jnp.where(a_is_smaller, a, b)
  denominator = jnp.where(a_is_smaller, b, a)
  ratio = -numerator / denominator
  inv_norm = lax.rsqrt(1 + abs(ratio)**2)
  scaled_ratio = inv_norm * ratio
  cs = jnp.where(b_is_zero, 1,
                 jnp.where(a_is_smaller, scaled_ratio, inv_norm))
  sn = jnp.where(b_is_zero, 0,
                 jnp.where(a_is_smaller, inv_norm, scaled_ratio))
  return cs, sn
def quantized_layernorm(x):
  """Layer-normalizes `x` over its last axis, quantizing every intermediate.

  Reads `hparams`, `dtype`, `num_features`, `scale_param`, `bias_param` and
  `self` from the enclosing scope.

  Args:
    x: the inputs.

  Returns:
    The normalized values cast to `self.dtype`.
  """
  prec = hparams.quant_hparams.prec
  fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
  quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

  def to_quantized(x):
    """Quantizes `x` with the float format configured above."""
    return quant_ops.to_quantized(x, dtype=dtype)

  # If epsilon is too small to represent in the quantized format, we set it
  # to the minimal representative non-zero value to avoid the possibility of
  # dividing by zero.
  fp_bounds = quantization.fp_cast.get_bounds(
      prec.exp_min, prec.exp_max, prec.sig_bits)
  epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
  quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

  # If the reciprocal of the quantized number of features is too small to
  # represent in the quantized format, we set it to the minimal
  # representative nonzero value so that the mean and variance are not
  # trivially 0.
  # NOTE(review): the fallback actually used is `quantized_epsilon`, which is
  # only the minimal representable value when `self.epsilon` is below the
  # flush-to-zero bound -- confirm this is intended.
  num_features_quantized = to_quantized(
      jnp.array(num_features, dtype=dtype))
  num_features_recip_quantized = to_quantized(
      jnp.reciprocal(num_features_quantized))
  num_features_recip_quantized = jax.lax.cond(
      jax.lax.eq(num_features_recip_quantized, 0.0),
      lambda _: quantized_epsilon,
      lambda _: num_features_recip_quantized, None)

  # Mean over the last axis: quantized sum times quantized 1/num_features.
  x_quantized = to_quantized(x)
  x_sum_quantized_reduction = quantization.quantized_sum(
      x_quantized, axis=-1, keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sum = to_quantized(x_sum_quantized_reduction)
  mean = to_quantized(x_sum * num_features_recip_quantized)
  # NOTE(review): the subtraction uses the original `x`, not `x_quantized`;
  # the result is quantized afterwards -- confirm this is intended.
  x_minus_mean = to_quantized(x - mean)
  # Variance as the mean of squared deviations, again via a quantized sum.
  x_sq = to_quantized(lax.square(x_minus_mean))
  x_sq_sum_quantized_reduction = quantization.quantized_sum(
      x_sq, axis=-1, keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
  var = to_quantized(x_sq_sum * num_features_recip_quantized)
  # Prevent division by zero.
  var_plus_epsilon = to_quantized(var + quantized_epsilon)
  mul = to_quantized(lax.rsqrt(var_plus_epsilon))
  if self.use_scale:
    quantized_scale_param = to_quantized(scale_param)
    mul = to_quantized(mul * quantized_scale_param)
  y = to_quantized(x_minus_mean * mul)
  if self.use_bias:
    quantized_bias_param = to_quantized(bias_param)
    y = to_quantized(y + quantized_bias_param)
  return y.astype(self.dtype)
def batch_norm(scope: Scope,
               x,
               use_running_average=False,
               axis=-1,
               momentum=0.99,
               epsilon=1e-5,
               dtype=jnp.float32,
               bias=True,
               scale=True,
               bias_init=initializers.zeros,
               scale_init=initializers.ones,
               axis_name=None,
               axis_index_groups=None,
               kind='batch_stats'):
  """Batch-normalizes `x` using variables and parameters stored in `scope`.

  Args:
    scope: the scope holding the running statistics (collection `kind`) and
      the `scale`/`bias` parameters.
    x: the input to be normalized; it is cast to float32.
    use_running_average: if true, the stored running statistics are used
      instead of the statistics of `x`.
    axis: the feature axis (or tuple of axes) that is not reduced over.
    momentum: decay rate of the running statistics.
    epsilon: added to the variance before the inverse square root.
    dtype: dtype of the returned array.
    bias: if True, a learned bias is added.
    scale: if True, the output is multiplied by a learned scale.
    bias_init: initializer for the bias.
    scale_init: initializer for the scale.
    axis_name: if set, statistics are averaged with `lax.pmean` over this axis.
    axis_index_groups: passed through to `lax.pmean`.
    kind: name of the collection that stores the running statistics.

  Returns:
    The normalized input, cast to `dtype`.
  """
  x = jnp.asarray(x, jnp.float32)
  feature_axes = axis if isinstance(axis, tuple) else (axis, )
  feature_axes = _absolute_dims(x.ndim, feature_axes)
  reduce_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)

  def averaged(values):
    """Mean over the reduced axes, also averaged across devices if named."""
    local = jnp.mean(values, reduce_axes, keepdims=True)
    if axis_name is None:
      return local
    return lax.pmean(local, axis_name=axis_name,
                     axis_index_groups=axis_index_groups)

  batch_mean = averaged(x)
  stats_shape = jnp.squeeze(batch_mean).shape
  batch_mean_sq = averaged(jnp.square(x))
  batch_var = batch_mean_sq - jnp.square(batch_mean)
  broadcast_shape = batch_mean.shape

  # Initialization is detected before the variables are created below.
  first_call = not scope.has_variable(kind, 'mean')
  running_mean = scope.variable(kind, 'mean', jnp.zeros, stats_shape)
  running_var = scope.variable(kind, 'var', jnp.ones, stats_shape)

  if use_running_average:
    used_mean = jnp.reshape(running_mean.value, broadcast_shape)
    used_var = jnp.reshape(running_var.value, broadcast_shape)
  else:
    used_mean, used_var = batch_mean, batch_var
    if not first_call:
      step = 1. - momentum
      running_mean.value = running_mean.value + step * (
          jnp.squeeze(batch_mean) - running_mean.value)
      running_var.value = running_var.value + step * (
          jnp.squeeze(batch_var) - running_var.value)

  centered = x - used_mean
  multiplier = lax.rsqrt(used_var + epsilon)
  if scale:
    gamma = scope.param('scale', scale_init, stats_shape)
    multiplier = multiplier * gamma.reshape(broadcast_shape)
  result = centered * multiplier
  if bias:
    beta = scope.param('bias', bias_init, stats_shape)
    result = result + beta.reshape(broadcast_shape)
  return jnp.asarray(result, dtype)
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
  """Centers `x` by `mean` and rescales it by `1 / sqrt(variance + epsilon)`.

  Args:
    x: the array to normalize.
    axis: axis (or axes) over which missing statistics are computed.
    mean: precomputed mean; computed from `x` over `axis` when None.
    variance: precomputed variance; computed from `x` when None.
    epsilon: added to the variance before the inverse square root.

  Returns:
    The normalized array.
  """
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True)
  if variance is None:
    # E[x^2] - E[x]^2 is traditionally considered less accurate than
    # jnp.var's mean((x - mean(x))**2), but it may be faster and, for typical
    # activation distributions in low-precision arithmetic, even more
    # accurate inside neural network normalization layers.
    second_moment = jnp.mean(x**2, axis, keepdims=True)
    variance = second_moment - mean**2
  centered = x - mean
  inv_std = lax.rsqrt(variance + epsilon)
  return centered * inv_std
def __call__(self, x):
  """Applies group normalization to the input (arxiv.org/abs/1803.08494).

  Args:
    x: the input of shape N...C, where N is a batch dimension and C is a
      channels dimensions. `...` represents an arbitrary number of extra
      dimensions that are used to accumulate statistics over.

  Returns:
    Normalized inputs (the same shape as inputs).

  Raises:
    ValueError: if not exactly one of `num_groups` / `group_size` is set, or
      if the resulting grouping does not evenly divide the channels.
  """
  x = jnp.asarray(x, jnp.float32)
  # Exactly one of the two grouping options must be provided.
  if (self.num_groups is None) == (self.group_size is None):
    raise ValueError('Either `num_groups` or `group_size` should be '
                     'specified, but not both of them.')

  channels = x.shape[-1]
  if self.group_size is None:
    num_groups = self.num_groups
  else:
    if channels % self.group_size != 0:
      raise ValueError(
          'Number of channels ({}) is not multiple of the '
          'group size ({}).'.format(channels, self.group_size))
    num_groups = channels // self.group_size

  if num_groups <= 0 or channels % num_groups != 0:
    raise ValueError('Number of groups ({}) does not divide the number'
                     ' of channels ({}).'.format(num_groups, channels))

  input_shape = x.shape
  # Split the channel axis into (groups, channels per group).
  grouped = x.reshape(input_shape[:-1] + (num_groups, channels // num_groups))
  # Reduce over the extra dimensions and the within-group channel axis.
  reduction_axis = list(range(1, grouped.ndim - 2)) + [grouped.ndim - 1]
  mean = jnp.mean(grouped, axis=reduction_axis, keepdims=True)
  mean_of_squares = jnp.mean(jnp.square(grouped), axis=reduction_axis,
                             keepdims=True)
  var = mean_of_squares - jnp.square(mean)
  normalized = (grouped - mean) * lax.rsqrt(var + self.epsilon)
  normalized = normalized.reshape(input_shape)

  # Parameters broadcast over everything except the channel axis.
  feature_shape = (1,) * (len(input_shape) - 1) + (input_shape[-1],)
  if self.use_scale:
    normalized = normalized * self.param('scale', self.scale_init,
                                         feature_shape)
  if self.use_bias:
    normalized = normalized + self.param('bias', self.bias_init,
                                         feature_shape)
  return normalized.astype(self.dtype)
def normalize(x: Array,
              axis: Optional[Union[int, Tuple[int, ...]]] = -1,
              mean: Optional[Array] = None,
              variance: Optional[Array] = None,
              epsilon: Array = 1e-5) -> Array:
  """Subtracts the mean from `x` and divides by `sqrt(variance + epsilon)`.

  Statistics that are not supplied are computed from `x` over `axis`.
  """
  used_mean = jnp.mean(x, axis, keepdims=True) if mean is None else mean
  if variance is None:
    # E[x^2] - E[x]^2 is traditionally considered less accurate than
    # jnp.var's mean((x - mean(x))**2), but it may be faster and, for typical
    # activation distributions in low-precision arithmetic, even more
    # accurate inside neural network normalization layers.
    mean_of_squares = jnp.mean(jnp.square(x), axis, keepdims=True)
    used_variance = mean_of_squares - jnp.square(used_mean)
  else:
    used_variance = variance
  inv_std = lax.rsqrt(used_variance + epsilon)
  return (x - used_mean) * inv_std
def unquantized_layernorm(x):
  """Layer-normalizes `x` over its last axis without quantization.

  Reads `num_features`, `scale_param`, `bias_param` and `self` from the
  enclosing scope.

  Args:
    x: the inputs.

  Returns:
    The normalized values cast to `self.dtype`.
  """
  inv_count = jnp.reciprocal(num_features)
  centered = x - jnp.sum(x, axis=-1, keepdims=True) * inv_count
  # Variance as the mean of squared deviations from the mean.
  squared_dev_total = jnp.sum(lax.square(centered), axis=-1, keepdims=True)
  variance = squared_dev_total * inv_count
  multiplier = lax.rsqrt(variance + self.epsilon)
  if self.use_scale:
    multiplier = multiplier * scale_param
  normalized = centered * multiplier
  if self.use_bias:
    normalized = normalized + bias_param
  return normalized.astype(self.dtype)
def segment_normalize(data, segment_ids, num_segments, eps=1e-5):
  """Z-score normalizes `data` separately within each segment.

  Args:
    data: values to normalize segment-wise.
    segment_ids: segment index of each entry of `data`.
    num_segments: total number of segments.
    eps: added to the variance before the inverse square root.

  Returns:
    array containing data normalized segment-wise, with NaNs replaced via
    `jnp.nan_to_num`.
  """
  per_segment_mean = segment_mean(data, segment_ids, num_segments)
  per_segment_var = segment_variance(data, segment_ids, num_segments)
  # Gather the per-segment statistics back to one value per data entry.
  centered = data - per_segment_mean[segment_ids]
  inv_std = lax.rsqrt(per_segment_var[segment_ids] + eps)
  return jnp.nan_to_num(centered * inv_std)
def apply(self,
          x,
          num_groups=32,
          group_size=None,
          epsilon=1e-6,
          dtype=jnp.float32,
          bias=True,
          scale=True,
          bias_init=initializers.zeros,
          scale_init=initializers.ones):
  """Applies group normalization to the input (arxiv.org/abs/1803.08494).

  This op is similar to batch normalization, but statistics are shared across
  equally-sized groups of channels and not shared across batch dimension.
  Thus, group normalization does not depend on the batch composition and does
  not require maintaining internal state for storing statistics.

  The user should either specify the total number of channel groups or the
  number of channels per group. Because `num_groups` defaults to 32, pass
  `num_groups=None` explicitly when specifying `group_size`.

  Args:
    x: the input of shape N...C, where N is a batch dimension and C is a
      channels dimensions. `...` represents an arbitrary number of extra
      dimensions that are used to accumulate statistics over.
    num_groups: the total number of channel groups. The default value of 32 is
      proposed by the original group normalization paper.
    group_size: the number of channels in a group.
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the returned array (default: float32). The
      statistics themselves are computed in float32.
    bias: If True, bias (beta) is added.
    scale: If True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.
    bias_init: Initializer for bias, by default, zero.
    scale_init: Initializer for scale, by default, one.

  Returns:
    Normalized inputs (the same shape as inputs).

  Raises:
    ValueError: if not exactly one of `num_groups` / `group_size` is given,
      if `group_size` does not divide the number of channels, or if the
      number of groups is non-positive or does not divide the number of
      channels.
  """
  x = jnp.asarray(x, jnp.float32)
  if ((num_groups is None and group_size is None) or
      (num_groups is not None and group_size is not None)):
    raise ValueError('Either `num_groups` or `group_size` should be '
                     'specified, but not both of them.')

  channels = x.shape[-1]
  if group_size is not None:
    if channels % group_size != 0:
      raise ValueError(
          'Number of channels ({}) is not multiple of the '
          'group size ({}).'.format(channels, group_size))
    num_groups = channels // group_size

  # Fail early with a clear message instead of an opaque reshape error when
  # the group count (e.g. the default of 32) does not divide the channels.
  if num_groups <= 0 or channels % num_groups != 0:
    raise ValueError('Number of groups ({}) does not divide the number'
                     ' of channels ({}).'.format(num_groups, channels))

  input_shape = x.shape
  # Split the channel axis into (groups, channels per group).
  group_shape = x.shape[:-1] + (num_groups, channels // num_groups)
  x = x.reshape(group_shape)

  # Reduce over the extra dimensions and the within-group channel axis.
  reduction_axis = [d for d in range(1, x.ndim - 2)] + [x.ndim - 1]
  mean = jnp.mean(x, axis=reduction_axis, keepdims=True)
  mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis,
                             keepdims=True)
  var = mean_of_squares - jnp.square(mean)
  x = (x - mean) * lax.rsqrt(var + epsilon)
  x = x.reshape(input_shape)

  # Parameters broadcast over everything except the channel axis.
  feature_shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]])
  if scale:
    x = x * self.param('scale', feature_shape, scale_init)
  if bias:
    x = x + self.param('bias', feature_shape, bias_init)
  return x.astype(dtype)
def apply(self,
          x,
          batch_stats=None,
          use_running_average=False,
          axis=-1,
          momentum=0.99,
          epsilon=1e-5,
          dtype=jnp.float32,
          bias=True,
          scale=True,
          bias_init=initializers.zeros,
          scale_init=initializers.ones,
          axis_name=None):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    batch_stats: a `flax.nn.Collection` used to store an exponential moving
      average of the batch statistics (default: None).
    use_running_average: if true, the statistics stored in batch_stats
      will be used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of
      the batch statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the returned array (default: float32). The
      statistics themselves are computed in float32.
    bias: if True, bias (beta) is added.
    scale: if True, multiply by scale (gamma).
      When the next layer is linear (also e.g. nn.relu), this can be disabled
      since the scaling will be done by the next layer.
    bias_init: initializer for bias, by default, zero.
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).

  Returns:
    Normalized inputs (the same shape as inputs).

  Raises:
    ValueError: if `use_running_average` is true but no running statistics
      are available (module is not stateful and no `batch_stats` given).
  """
  x = jnp.asarray(x, jnp.float32)
  axis = axis if isinstance(axis, tuple) else (axis, )
  axis = _absolute_dims(x.ndim, axis)
  # Broadcastable shape: feature axes keep their size, all others are 1.
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  # Shape of the stored statistics and parameters: feature axes only.
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
  # Running statistics only exist in a stateful context or when a collection
  # is supplied.
  # NOTE(review): `batch_stats` is tested by truthiness, not `is not None`.
  if self.is_stateful() or batch_stats:
    ra_mean = self.state('mean', reduced_feature_shape,
                         initializers.zeros, collection=batch_stats)
    ra_var = self.state('var', reduced_feature_shape,
                        initializers.ones, collection=batch_stats)
  else:
    ra_mean = None
    ra_var = None

  if use_running_average:
    if ra_mean is None:
      raise ValueError('batch_stats should be provided if '
                       'use_running_averages is True')
    mean, var = ra_mean.value, ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    # Cross-device averaging is skipped during initialization.
    if axis_name is not None and not self.is_initializing():
      mean = lax.pmean(mean, axis_name=axis_name)

    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if axis_name is not None and not self.is_initializing():
      mean2 = lax.pmean(mean2, axis_name=axis_name)
    # Variance as E[x^2] - E[x]^2.
    var = mean2 - lax.square(mean)

    # NOTE(review): `ra_mean` is tested by truthiness here -- confirm the
    # state object is always truthy when present.
    if ra_mean and not self.is_initializing():
      # Exponential moving average of the batch statistics.
      ra_mean.value = momentum * ra_mean.value + (1 - momentum) * mean
      ra_var.value = momentum * ra_var.value + (1 - momentum) * var

  y = x - mean.reshape(feature_shape)
  mul = lax.rsqrt(var + epsilon)
  if scale:
    mul = mul * self.param(
        'scale', reduced_feature_shape, scale_init).reshape(feature_shape)
  y = y * mul
  if bias:
    y = y + self.param(
        'bias', reduced_feature_shape, bias_init).reshape(feature_shape)
  return jnp.asarray(y, dtype)
def __call__(self, x, use_running_average: Optional[bool] = None):
  """Normalizes the input using batch statistics.

  NOTE:
  During initialization (when parameters are mutable) the running average
  of the batch statistics will not be updated. Therefore, the inputs
  fed during initialization don't need to match that of the actual input
  distribution and the reduction axis (set with `axis_name`) does not have
  to exist.

  Args:
    x: the input to be normalized.
    use_running_average: if true, the statistics stored in batch_stats
      will be used instead of computing the batch statistics on the input.
      Merged with `self.use_running_average` through `merge_param`.

  Returns:
    Normalized inputs (the same shape as inputs), cast to `self.dtype`.
  """
  use_running_average = merge_param('use_running_average',
                                    self.use_running_average,
                                    use_running_average)
  x = jnp.asarray(x, jnp.float32)
  axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
  axis = _absolute_dims(x.ndim, axis)
  # Broadcastable shape: feature axes keep their size, all others are 1.
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  # Shape of the stored statistics and parameters: feature axes only.
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

  # see NOTE above on initialization behavior
  initializing = self.is_mutable_collection('params')

  ra_mean = self.variable('batch_stats', 'mean',
                          lambda s: jnp.zeros(s, jnp.float32),
                          reduced_feature_shape)
  ra_var = self.variable('batch_stats', 'var',
                         lambda s: jnp.ones(s, jnp.float32),
                         reduced_feature_shape)

  if use_running_average:
    mean, var = ra_mean.value, ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if self.axis_name is not None and not initializing:
      # Average both moments across devices with a single collective.
      concatenated_mean = jnp.concatenate([mean, mean2])
      mean, mean2 = jnp.split(
          lax.pmean(concatenated_mean,
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
    # Variance as E[x^2] - E[x]^2.
    var = mean2 - lax.square(mean)

    if not initializing:
      # Exponential moving average of the batch statistics.
      ra_mean.value = self.momentum * ra_mean.value + (
          1 - self.momentum) * mean
      ra_var.value = self.momentum * ra_var.value + (
          1 - self.momentum) * var

  y = x - mean.reshape(feature_shape)
  mul = lax.rsqrt(var + self.epsilon)
  if self.use_scale:
    scale = self.param('scale', self.scale_init,
                       reduced_feature_shape).reshape(feature_shape)
    mul = mul * scale
  y = y * mul
  if self.use_bias:
    bias = self.param('bias', self.bias_init,
                      reduced_feature_shape).reshape(feature_shape)
    y = y + bias
  return jnp.asarray(y, self.dtype)
def __call__(self, x, use_running_average: Optional[bool] = None):
  """Normalizes the input using batch statistics.

  The batch is split into sub-batches of `virtual_batch_size` examples that
  are normalized independently. Batch moments are accumulated over
  `self.total_batch_size // self.batch_size` calls before being folded into
  the running averages.

  NOTE:
  During initialization (when parameters are mutable) the running average
  of the batch statistics will not be updated. Therefore, the inputs
  fed during initialization don't need to match that of the actual input
  distribution and the reduction axis (set with `axis_name`) does not have
  to exist.

  Args:
    x: the input to be normalized.
    use_running_average: if true, the statistics stored in batch_stats
      will be used instead of computing the batch statistics on the input.

  Returns:
    Normalized inputs (the same shape as inputs), cast to `self.dtype`.
  """
  use_running_average = nn.module.merge_param('use_running_average',
                                              self.use_running_average,
                                              use_running_average)
  virtual_batch_size = self.virtual_batch_size
  batch_axis = _get_batch_axis(self.data_format, x, virtual_batch_size,
                               use_running_average, self.axis_index_groups)
  if virtual_batch_size is None:
    virtual_batch_size = self.batch_size

  if use_running_average:
    # Virtual batch norm is not used during evaluation, and we cannot
    # guarantee the train and eval batch sizes are the same, so we use a
    # single virtual batch of size batch_size, and take the first element in
    # the running average array, assuming they have been properly synced
    # across their first dim.
    virtual_batch_size = x.shape[batch_axis]

  x = jnp.asarray(x, jnp.float32)
  # Note that this should only ever default to the first case if we are
  # passing in a batch `x` with less examples than `virtual_batch_size`, which
  # should only happen if we are initializing with dummy variables (typically
  # of batch size 2).
  num_sub_batches = max(1, x.shape[batch_axis] // virtual_batch_size)
  input_shape = x.shape
  axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
  axis = _absolute_dims(x.ndim, axis)
  # Broadcastable shape: feature axes keep their size, all others are 1.
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  # Shape of the scale/bias parameters: feature axes only.
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  # Add an additional axis because we are going to reshape `x` to have a
  # leading dim of size `num_sub_batches`.
  reduction_axis = tuple(i + 1 for i in range(x.ndim) if i not in axis)
  sub_batched_shape = (
      num_sub_batches,
      *x.shape[:batch_axis],
      # Necessary for when passing in a batch `x` with less examples than
      # `virtual_batch_size`, which should only happen if we are initializing
      # with dummy variables (typically of batch size 2).
      min(x.shape[batch_axis], virtual_batch_size),
      *x.shape[batch_axis + 1:])
  x = jnp.reshape(x, sub_batched_shape)
  ra_mean = self.variable('batch_stats', 'batch_norm_running_mean',
                          lambda s: jnp.zeros(s, jnp.float32),
                          feature_shape)
  ra_var = self.variable('batch_stats', 'batch_norm_running_var',
                         lambda s: jnp.ones(s, jnp.float32),
                         feature_shape)
  # If using gradient accumulation, use these to accumulate the activations
  # for the current batch before folding them into the running average.
  mean_accumulator = self.variable('batch_stats',
                                   'batch_norm_mean_accumulator',
                                   lambda s: jnp.zeros(s, jnp.float32),
                                   feature_shape)
  mean2_accumulator = self.variable('batch_stats',
                                    'batch_norm_mean2_accumulator',
                                    lambda s: jnp.zeros(s, jnp.float32),
                                    feature_shape)
  # A counter that is used to determine which accumulation pass we are
  # currently in. This will increment from 0 until we have accumulated
  # gradients calculated on `self.total_batch_size` examples. This should only
  # ever be saved on disk as 0 because we only checkpoint after accumulating
  # enough examples to make an update.
  grad_accum_counter = self.variable('batch_stats', 'grad_accum_counter',
                                     lambda s: jnp.zeros(s, jnp.int32), [])

  # See NOTE above on initialization behavior.
  initializing = self.is_mutable_collection('params')

  # Number of calls whose moments are pooled into one running-average update.
  if self.total_batch_size is None:
    passes_per_total_batch = 1
  else:
    passes_per_total_batch = self.total_batch_size // self.batch_size

  if use_running_average:
    # Note that we assume that the values across the first axis have been
    # properly synchronized.
    # `num_sub_batches` is 1 on this path, so the leading axis of size 1
    # added here matches the reshape applied further below.
    mean = jnp.expand_dims(ra_mean.value, 0)
    var = jnp.expand_dims(ra_var.value, 0)
  else:
    # Shape (num_sub_batches, x.shape[axis]).
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if self.axis_name is not None and not initializing:
      # Average both moments across devices with a single collective.
      concatenated_mean = jnp.concatenate([mean, mean2])
      mean, mean2 = jnp.split(
          lax.pmean(concatenated_mean,
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
    # Variance as E[x^2] - E[x]^2, per sub-batch.
    var = mean2 - lax.square(mean)

    if not initializing:
      # Pool the moments of all sub-batches of this call into the
      # accumulators.
      # NOTE(review): the accumulators have shape `feature_shape` while the
      # increments come from `mean` reduced over axis 0; this relies on
      # broadcasting -- confirm the shapes line up for multi-axis features.
      mean_accumulator.value += jnp.mean(mean, axis=0)
      mean2_accumulator.value += jnp.mean(mean2, axis=0)
      grad_accum_counter_inc = grad_accum_counter.value + 1
      # This will be 0 for all gradient accumulation passes except for the
      # last one when we have seen enough examples to make an update to the
      # running averages.
      should_update_ra = grad_accum_counter_inc // passes_per_total_batch
      # When `should_update_ra` is 0 the two assignments below leave the
      # running averages unchanged; when it is 1 they are a standard
      # exponential-moving-average step with the accumulated moments.
      ra_mean_update = (should_update_ra * mean_accumulator.value /
                        grad_accum_counter_inc)
      ra_mean.value = ((1 - should_update_ra * (1 - self.momentum)) *
                       ra_mean.value + (1 - self.momentum) * ra_mean_update)
      ra_var_update = should_update_ra * (
          mean2_accumulator.value / grad_accum_counter_inc - lax.square(
              mean_accumulator.value / grad_accum_counter_inc))
      ra_var.value = ((1 - should_update_ra * (1 - self.momentum)) *
                      ra_var.value + (1 - self.momentum) * ra_var_update)

      grad_accum_counter.value = (grad_accum_counter_inc %
                                  passes_per_total_batch)
      # Reset the activation accumulators every `passes_per_total_batch` steps
      # (jnp.sign == 0 if grad_accum_counter == 0).
      mean_accumulator.value *= jnp.sign(grad_accum_counter.value)
      mean2_accumulator.value *= jnp.sign(grad_accum_counter.value)

  y = x - mean.reshape((num_sub_batches, *feature_shape))
  mul = lax.rsqrt(
      var.reshape((num_sub_batches, *feature_shape)) + self.epsilon)
  if self.use_scale:
    # Scale and bias are shared by all sub-batches (leading axis of size 1).
    mul = mul * self.param('scale', self.scale_init,
                           reduced_feature_shape).reshape(
                               (1, *feature_shape))
  y = y * mul
  if self.use_bias:
    y = y + self.param('bias', self.bias_init,
                       reduced_feature_shape).reshape((1, *feature_shape))
  # Undo the sub-batch reshape.
  y = jnp.reshape(y, input_shape)
  return jnp.asarray(y, self.dtype)
def pairwise_cosine_similarity(embeddings1, embeddings2, eps=1e-8):
  """Computes the cosine similarity between every pair of embeddings.

  Args:
    embeddings1: first set of embeddings, one per row.
    embeddings2: second set of embeddings, one per row.
    eps: lower bound on the product of the two norms, guarding against
      division by zero.

  Returns:
    Matrix whose entry (i, j) is the cosine similarity between row i of
    `embeddings1` and row j of `embeddings2`.
  """
  row_sq_norms = jnp.sum(embeddings1**2, axis=-1, keepdims=True)
  col_sq_norms = jnp.sum(embeddings2**2, axis=-1)
  col_sq_norms = jnp.expand_dims(col_sq_norms, axis=0)
  # The product of squared norms broadcasts to one value per pair.
  clipped_sq_norm_product = jnp.maximum(row_sq_norms * col_sq_norms, eps**2)
  transposed = embeddings2.transpose()
  similarities = jnp.matmul(embeddings1, transposed)
  return similarities * rsqrt(clipped_sq_norm_product)
def apply(self,
          x,
          batch_stats=None,
          use_running_average=False,
          axis=-1,
          momentum=0.99,
          epsilon=1e-5,
          dtype=jnp.float32,
          bias=True,
          scale=True,
          bias_init=initializers.zeros,
          scale_init=initializers.ones,
          axis_name=None,
          axis_index_groups=None,
          virtual_batch_size=None,
          data_format=None):
  """Normalizes the input using batch statistics.

  Forked from the original flax nn.BatchNorm layer, this allows users to have
  multiple EMAs per device, one for each virtual batch size. For example, if
  the per-device batch size is 128 and the user specifies
  `virtual_batch_size=32`, 4 EMAs will be created on each device, each updated
  with 1/4 of the per-device batch on each forward pass.

  WARNING: the multiple per-device EMAs this creates need to be manually
  synchronized within each device before being used for evaluation, or when
  synchronizing batch norm statistic across devices.

  Args:
    x: the input to be normalized.
    batch_stats: a `flax.nn.Collection` used to store an exponential moving
      average of the batch statistics (default: None).
    use_running_average: if true, the statistics stored in batch_stats
      will be used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of
      the batch statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the returned array (default: float32). The
      statistics themselves are computed in float32.
    bias: if True, bias (beta) is added.
    scale: if True, multiply by scale (gamma).
      When the next layer is linear (also e.g. nn.relu), this can be disabled
      since the scaling will be done by the next layer.
    bias_init: initializer for bias, by default, zero.
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).
    axis_index_groups: groups of axis indices within that named axis
      representing subsets of devices to reduce over (default: None). For
      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
      examples on the first two and last two devices. See `jax.lax.psum` for
      more details.
    virtual_batch_size: the size of the virtual batches to construct on
      each device, which will be used to normalize sub-batches of each
      per-device batch. Will create a running average with a leading dim of
      size `x.shape[batch_axis] // virtual_batch_size`, one for each
      sub-batch. Note that the first dim of each state must be synchronized
      whenever synchronizing batch norm running averages. Must evenly divide
      the per-device batch size (as determined by `x`), and cannot be combined
      with `axis_index_groups`. Passing the default value of None will
      replicate the existing nn.BatchNorm behavior without virtual batches.
    data_format: only used when `virtual_batch_size` is set, to determine the
      batch axis.

  Returns:
    Normalized inputs (the same shape as inputs).

  Raises:
    ValueError: if `use_running_average` is true but no running statistics
      are available (module is not stateful and no `batch_stats` given).
  """
  batch_axis = _get_batch_axis(
      data_format, x, virtual_batch_size, use_running_average,
      axis_index_groups)
  if virtual_batch_size is None:
    virtual_batch_size = x.shape[batch_axis]

  if use_running_average:
    # Virtual batch norm is not used during evaluation, and we cannot
    # guarantee the train and eval batch sizes are the same, so we use a
    # single virtual batch of size batch_size, and take the first element in
    # the running average array, assuming they have been properly synced
    # across their first dim.
    virtual_batch_size = x.shape[batch_axis]

  x = jnp.asarray(x, jnp.float32)
  num_sub_batches = x.shape[batch_axis] // virtual_batch_size
  input_shape = x.shape
  axis = axis if isinstance(axis, tuple) else (axis, )
  axis = _absolute_dims(x.ndim, axis)
  # Broadcastable shape: feature axes keep their size, all others are 1.
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  # Shape of the scale/bias parameters: feature axes only.
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  # Add an additional axis because we are going to reshape `x` to have a
  # leading dim of size `num_sub_batches`.
  reduction_axis = tuple(i + 1 for i in range(x.ndim) if i not in axis)
  sub_batched_shape = (num_sub_batches, *x.shape[:batch_axis],
                       virtual_batch_size, *x.shape[batch_axis + 1:])
  x = jnp.reshape(x, sub_batched_shape)
  # Running statistics only exist in a stateful context or when a collection
  # is supplied; one row is kept per sub-batch.
  # NOTE(review): `batch_stats` is tested by truthiness, not `is not None`.
  if self.is_stateful() or batch_stats:
    ra_means = self.state(
        'batch_norm_running_mean',
        (num_sub_batches, *reduced_feature_shape),
        initializers.zeros,
        collection=batch_stats)
    ra_vars = self.state(
        'batch_norm_running_var',
        (num_sub_batches, *reduced_feature_shape),
        initializers.ones,
        collection=batch_stats)
  else:
    ra_means = None
    ra_vars = None

  if use_running_average:
    if ra_means is None:
      raise ValueError(
          'when use_running_averages is True '
          'either use a stateful context or provide batch_stats')
    # Note that we assume that the values across the first axis have been
    # properly synchronized.
    mean = jnp.expand_dims(ra_means.value[0], 0)
    var = jnp.expand_dims(ra_vars.value[0], 0)
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if axis_name is not None and not self.is_initializing():
      # Average both moments across devices with a single collective.
      concatenated_mean = jnp.concatenate([mean, mean2])
      mean, mean2 = jnp.split(
          lax.pmean(
              concatenated_mean,
              axis_name=axis_name,
              axis_index_groups=axis_index_groups), 2)
    # Variance as E[x^2] - E[x]^2, per sub-batch.
    var = mean2 - lax.square(mean)

    # NOTE(review): `ra_means` is tested by truthiness here -- confirm the
    # state object is always truthy when present.
    if ra_means and not self.is_initializing():
      # Exponential moving average of the per-sub-batch statistics.
      ra_means.value = momentum * ra_means.value + (1 - momentum) * mean
      ra_vars.value = momentum * ra_vars.value + (1 - momentum) * var

  y = x - mean.reshape((num_sub_batches, *feature_shape))
  mul = lax.rsqrt(
      var.reshape((num_sub_batches, *feature_shape)) + epsilon)
  if scale:
    # Scale and bias are shared by all sub-batches (leading axis of size 1).
    mul = mul * self.param(
        'scale', reduced_feature_shape,
        scale_init).reshape((1, *feature_shape))
  y = y * mul
  if bias:
    y = y + self.param(
        'bias', reduced_feature_shape,
        bias_init).reshape((1, *feature_shape))
  # Undo the sub-batch reshape.
  y = jnp.reshape(y, input_shape)
  return jnp.asarray(y, dtype)
def __call__(self, inputs, use_running_stats=None):
  """Applies spin-spherical batch normalization to `inputs`.

  Every (spin, channel) pair is normalized independently. Statistics are
  reduced over the batch and the two spherical grid dimensions, i.e. axes
  (0, 1, 2).

  Args:
    inputs: An array of dimensions (batch_size, resolution, resolution,
      n_spins_in, n_channels_in).
    use_running_stats: If true, normalize with the statistics stored in the
      "batch_stats" collection instead of statistics computed from `inputs`.
      Combined with `self.use_running_stats` through `merge_param`.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  use_running_stats = nn.module.merge_param(
      "use_running_stats", self.use_running_stats, use_running_stats)
  centered = self.centered

  # One statistic per (spin, channel) pair; `broadcast_shape` aligns those
  # statistics with the trailing two axes of the 5D input.
  n_spins, n_channels = inputs.shape[-2:]
  stats_shape = (n_spins, n_channels)
  broadcast_shape = (1, 1, 1) + stats_shape

  # The module counts as initializing while "variance" has not been created,
  # so this check has to come before the variable is declared below.
  is_init = not self.has_variable("batch_stats", "variance")
  var_state = self.variable(
      "batch_stats", "variance",
      lambda shape: jnp.ones(shape, jnp.float32), stats_shape)
  mean_state = None
  if centered:
    mean_state = self.variable(
        "batch_stats", "mean",
        lambda shape: jnp.zeros(shape, jnp.complex64), stats_shape)

  def grid_then_batch_mean(values):
    # Spherical mean over the grid dimensions, then a plain mean over batch.
    return jnp.mean(sphere_utils.spin_spherical_mean(values), axis=0)

  mean = None
  if not use_running_stats:
    if centered:
      mean = grid_then_batch_mean(inputs)

    # The complex variance is E[x x*] - E[x]E[x*]. For spin != 0, E[x] should
    # vanish (discretization can break this), so only E[x x*] is used.
    second_moment = grid_then_batch_mean(inputs * inputs.conj())

    # Average the statistics across devices, except while initializing.
    sync_devices = self.axis_name is not None and not is_init
    if sync_devices:
      if centered:
        mean = lax.pmean(mean, axis_name=self.axis_name)
      second_moment = lax.pmean(second_moment, axis_name=self.axis_name)

    # The imaginary part is negligible.
    variance = second_moment.real

    if not is_init:
      # Exponential moving average of the stored statistics.
      keep = self.momentum
      var_state.value = keep * var_state.value + (1 - keep) * variance
      if centered:
        mean_state.value = keep * mean_state.value + (1 - keep) * mean
  else:
    variance = var_state.value
    if centered:
      mean = mean_state.value

  outputs = inputs - mean.reshape(broadcast_shape) if centered else inputs

  # Fold the optional learned scale into the normalization factor so the
  # activations are multiplied only once.
  factor = lax.rsqrt(variance.reshape(broadcast_shape) + self.epsilon)
  if self.use_scale:
    gamma = self.param("scale", self.scale_init, stats_shape)
    factor = factor * gamma.reshape(broadcast_shape)
  outputs = outputs * factor

  if self.use_bias:
    beta = self.param("bias", self.bias_init, stats_shape)
    outputs = outputs + beta.reshape(broadcast_shape)

  return outputs