def _stats_to_bounds(self, stats_value):
  """Computes activation clipping bounds from activation statistics."""
  hyper = self.hyper
  maximum = stats_value.mean_batch_maximum
  minimum = stats_value.mean_batch_minimum
  mom = jnp.maximum(jnp.abs(maximum), jnp.abs(minimum))
  stddev_uncentered = lax.sqrt(stats_value.mean_sq)
  absdev_uncentered = stats_value.mean_abs
  stddev = lax.sqrt(stats_value.mean_sq - stats_value.mean**2)
  abs_mean = jnp.abs(stats_value.mean)
  if hyper.use_old_code:
    # Old way of computing the bound.
    if hyper.use_cams:
      # Upper confidence bound formula.
      return abs_mean + hyper.stddev_coeff * stddev
    elif hyper.use_mean_of_max:
      return mom
    else:
      return (hyper.mix_coeff * hyper.stddev_coeff * stddev_uncentered +
              (1 - hyper.mix_coeff) * hyper.absdev_coeff * absdev_uncentered)
  else:
    # New way of computing the bound.
    cams = abs_mean + hyper.cams_stddev_coeff * stddev
    return (hyper.fixed_bound + hyper.mean_of_max_coeff * mom +
            hyper.stddev_coeff * stddev_uncentered +
            hyper.absdev_coeff * absdev_uncentered +
            hyper.cams_coeff * cams)
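# Illustrative sketch (not from the source): evaluating the new-style bound
# above on made-up statistics and coefficients. The field names mirror the
# snippet; all numeric values are hypothetical.
import collections
import jax.numpy as jnp

Stats = collections.namedtuple(
    'Stats', ['mean', 'mean_abs', 'mean_sq',
              'mean_batch_maximum', 'mean_batch_minimum'])
stats = Stats(mean=0.1, mean_abs=0.8, mean_sq=1.0,
              mean_batch_maximum=3.0, mean_batch_minimum=-2.5)

mom = max(abs(stats.mean_batch_maximum), abs(stats.mean_batch_minimum))  # 3.0
stddev_uncentered = jnp.sqrt(stats.mean_sq)                              # 1.0
stddev = jnp.sqrt(stats.mean_sq - stats.mean**2)
cams = abs(stats.mean) + 2.0 * stddev             # with cams_stddev_coeff = 2.0
# fixed_bound + mean_of_max_coeff * mom + stddev_coeff * stddev_uncentered
#   + absdev_coeff * absdev_uncentered + cams_coeff * cams
bound = 0.0 + 0.5 * mom + 1.0 * stddev_uncentered + 0.0 * stats.mean_abs + 0.0 * cams
print(bound)  # 2.5 for these made-up numbers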
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  lr = hyper_params.learning_rate
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  t = step + 1.
  rho_inf = 2.0 / (1 - beta2) - 1

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  beta2_t = beta2**t
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2_t)
  rho_t = rho_inf - 2.0 * t * beta2_t / (1 - beta2_t)

  if rho_t <= 5:
    # The variance of the adaptive term is not yet reliable: fall back to a
    # momentum-only step size.
    step_size = 1.0 / (1 - beta1**t)
  else:
    # Rectified step size.
    step_size = lax.sqrt(
        (1 - beta2_t) * (rho_t - 4) / (rho_inf - 4) * (rho_t - 2) / rho_t *
        rho_inf / (rho_inf - 2)) / (1 - beta1**t)

  if rho_t <= 5:
    new_param = param - lr * step_size * grad_ema
    new_param -= lr * weight_decay * param
  else:
    denom = lax.sqrt(grad_sq_ema_corr) + hyper_params.eps
    new_param = param - lr * step_size * grad_ema / denom
    new_param -= lr * weight_decay * param

  new_state = _RAdamParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
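# Quick illustration (not part of the optimizer code) of when the rectification
# above switches on. With a typical beta2 = 0.999, rho_inf = 2/(1 - beta2) - 1
# = 1999, and rho_t only crosses the threshold of 5 after a few steps, so the
# earliest updates take the momentum-only branch.
beta2 = 0.999
rho_inf = 2.0 / (1 - beta2) - 1               # 1999.0
for t in (1.0, 5.0, 6.0, 100.0):
  beta2_t = beta2**t
  rho_t = rho_inf - 2.0 * t * beta2_t / (1 - beta2_t)
  print(t, rho_t)                             # ~1.0, ~5.0, ~6.0, ~98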
def _stats_to_bounds(self, stats_value):
  """Computes activation clipping bounds from activation statistics."""
  hyper = self.hyper
  if hyper.use_cams:
    # Upper confidence bound formula.
    return jnp.abs(stats_value.mean) + hyper.stddev_coeff * lax.sqrt(
        stats_value.mean_sq - stats_value.mean**2)
  elif hyper.use_mean_of_max:
    maximum = stats_value.mean_batch_maximum
    minimum = stats_value.mean_batch_minimum
    return jnp.maximum(jnp.abs(maximum), jnp.abs(minimum))
  else:
    stddev_uncentered = lax.sqrt(stats_value.mean_sq)
    absdev_uncentered = stats_value.mean_abs
    return (hyper.mix_coeff * hyper.stddev_coeff * stddev_uncentered +
            (1 - hyper.mix_coeff) * hyper.absdev_coeff * absdev_uncentered)
def _dct_ortho_norm(out, axis):
  factor = lax.concatenate(
      [lax.full((1,), 4, out.dtype),
       lax.full((out.shape[axis] - 1,), 2, out.dtype)], 0)
  factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
  return out / lax.sqrt(factor * out.shape[axis])
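# Side note (not from the source): for a length-N DCT-II this divides the k = 0
# coefficient by sqrt(4*N) and every other coefficient by sqrt(2*N), the usual
# "ortho" normalization. A quick look at the scale factors for N = 4:
import jax.numpy as jnp
N = 4
scale = 1.0 / jnp.sqrt(jnp.array([4.0] + [2.0] * (N - 1)) * N)
print(scale)  # [0.25, ~0.354, ~0.354, ~0.354] == [1/sqrt(16), 1/sqrt(8), ...]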
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  return lax.select(
      x1 == 0, x1,
      x1 * lax.sqrt(1 + lax.square(lax.div(
          x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))
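# Usage sketch (illustrative): factoring out the larger magnitude keeps the
# intermediate square from overflowing, unlike the naive formula.
import jax.numpy as jnp
a, b = jnp.float32(3e20), jnp.float32(4e20)
print(jnp.hypot(a, b))        # ~5e20
print(jnp.sqrt(a**2 + b**2))  # inf: the squares overflow float32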
def _sqrt2(x):
  # Double-double square root in the style of Dekker's sqrt2: `x` is a
  # (head, tail) pair whose sum represents a higher-precision value.
  x, xx = x
  c = lax.sqrt(x)                     # first approximation from the head
  u, uu = _mul12(c, c)                # double-double product c * c
  cc = (x - u - uu + xx) * 0.5 / c    # correction term for the residual
  y = c + cc
  yy = c - y + cc                     # rounding error of the sum, kept as tail
  return y, yy
def batchnorm(x, s, bias, mean, var, epsilon=1e-5):
  dims_x = len(x.shape)
  dim_ones = (1,) * (dims_x - 2)
  s = s.reshape(-1, *dim_ones)
  bias = bias.reshape(-1, *dim_ones)
  mean = mean.reshape(-1, *dim_ones)
  var = var.reshape(-1, *dim_ones)
  ot = s * (x - mean) / lax.sqrt(var + epsilon) + bias
  return ot
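# Usage sketch (illustrative values): per-channel affine batchnorm in NCHW
# layout, where the (C,)-shaped parameters broadcast over the spatial dims.
import jax.numpy as jnp
x = jnp.ones((2, 3, 4, 4))               # (N, C, H, W)
s = jnp.ones(3)
bias = jnp.zeros(3)
mean = jnp.zeros(3)
var = jnp.ones(3)
y = batchnorm(x, s, bias, mean, var)     # same shape as x
print(y.shape)                           # (2, 3, 4, 4)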
def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanstd", a)
  lax_internal._check_user_dtype_supported(dtype, "nanstd")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.nanstd is not supported.")
  return lax.sqrt(
      nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims,
             where=where))
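# Usage sketch: NaNs are ignored, so this matches the std of the remaining
# values.
import jax.numpy as jnp
a = jnp.array([1.0, 2.0, jnp.nan, 4.0])
print(jnp.nanstd(a))  # same as jnp.std(jnp.array([1., 2., 4.]))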
def _ndtri(p):
  """Implements ndtri core logic."""

  # Constants used in piece-wise rational approximations. Taken from the cephes
  # library:
  # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  p0 = list(reversed([-5.99633501014107895267E1,
                      9.80010754185999661536E1,
                      -5.66762857469070293439E1,
                      1.39312609387279679503E1,
                      -1.23916583867381258016E0]))
  q0 = list(reversed([1.0,
                      1.95448858338141759834E0,
                      4.67627912898881538453E0,
                      8.63602421390890590575E1,
                      -2.25462687854119370527E2,
                      2.00260212380060660359E2,
                      -8.20372256168333339912E1,
                      1.59056225126211695515E1,
                      -1.18331621121330003142E0]))
  p1 = list(reversed([4.05544892305962419923E0,
                      3.15251094599893866154E1,
                      5.71628192246421288162E1,
                      4.40805073893200834700E1,
                      1.46849561928858024014E1,
                      2.18663306850790267539E0,
                      -1.40256079171354495875E-1,
                      -3.50424626827848203418E-2,
                      -8.57456785154685413611E-4]))
  q1 = list(reversed([1.0,
                      1.57799883256466749731E1,
                      4.53907635128879210584E1,
                      4.13172038254672030440E1,
                      1.50425385692907503408E1,
                      2.50464946208309415979E0,
                      -1.42182922854787788574E-1,
                      -3.80806407691578277194E-2,
                      -9.33259480895457427372E-4]))
  p2 = list(reversed([3.23774891776946035970E0,
                      6.91522889068984211695E0,
                      3.93881025292474443415E0,
                      1.33303460815807542389E0,
                      2.01485389549179081538E-1,
                      1.23716634817820021358E-2,
                      3.01581553508235416007E-4,
                      2.65806974686737550832E-6,
                      6.23974539184983293730E-9]))
  q2 = list(reversed([1.0,
                      6.02427039364742014255E0,
                      3.67983563856160859403E0,
                      1.37702099489081330271E0,
                      2.16236993594496635890E-1,
                      1.34204006088543189037E-2,
                      3.28014464682127739104E-4,
                      2.89247864745380683936E-6,
                      6.79019408009981274425E-9]))

  dtype = lax.dtype(p).type
  shape = jnp.shape(p)

  def _create_polynomial(var, coeffs):
    """Compute n_th order polynomial via Horner's method."""
    coeffs = np.array(coeffs, dtype)
    if not coeffs.size:
      return jnp.zeros_like(var)
    return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

  maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
  # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
  # later on. The result from the computation when p == 0 is not used so any
  # number that doesn't result in NaNs is fine.
  sanitized_mcp = jnp.where(
      maybe_complement_p <= dtype(0.),
      jnp.full(shape, dtype(0.5)), maybe_complement_p)

  # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
  w = sanitized_mcp - dtype(0.5)
  ww = lax.square(w)
  x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                              _create_polynomial(ww, q0))
  x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

  # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
  # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
  # arrays based on whether p < exp(-32).
  z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
  first_term = z - lax.log(z) / z
  second_term_small_p = (_create_polynomial(dtype(1.) / z, p2) /
                         _create_polynomial(dtype(1.) / z, q2) / z)
  second_term_otherwise = (_create_polynomial(dtype(1.) / z, p1) /
                           _create_polynomial(dtype(1.) / z, q1) / z)
  x_for_small_p = first_term - second_term_small_p
  x_otherwise = first_term - second_term_otherwise

  x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p,
                jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

  x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
  infinity = jnp.full(shape, dtype(np.inf))
  x_nan_replaced = jnp.where(p <= dtype(0.0), -infinity,
                             jnp.where(p >= dtype(1.0), infinity, x))
  return x_nan_replaced
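# Usage sketch: the public entry point is jax.scipy.special.ndtri, the inverse
# of the standard normal CDF. Round-tripping through ndtr should recover the
# input probabilities.
import jax.numpy as jnp
from jax.scipy import special
p = jnp.array([0.025, 0.5, 0.975])
x = special.ndtri(p)    # approximately [-1.96, 0.0, 1.96]
print(special.ndtr(x))  # approximately [0.025, 0.5, 0.975]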
def_deriv(lax.erf_p,
          lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                            lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
  """
  Define the jet rule for a primitive in terms of a
  composition of simpler primitives.
  """
  jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
  x, = primals_in
  series, = series_in

  u = [x] + series
  primal_out = lax.erf_inv(x)
  v = [primal_out] + [None] * len(series)
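# Usage sketch (illustrative): the sqrt rule registered above is exercised
# whenever jet propagates a truncated Taylor series through lax.sqrt.
import jax.numpy as jnp
from jax.experimental.jet import jet
primal_out, series_out = jet(jnp.sqrt, (4.0,), ((1.0, 0.0),))
print(primal_out)  # 2.0; series_out holds the higher-order Taylor coefficients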
def apply(self,
          x,
          layer=LAYER_EVONORM_B0,
          nonlinearity=True,
          num_groups=32,
          group_size=None,
          batch_stats=None,
          use_running_average=False,
          axis=-1,
          momentum=0.99,
          epsilon=1e-5,
          dtype=jnp.float32,
          bias=True,
          scale=True,
          bias_init=initializers.zeros,
          scale_init=initializers.ones,
          axis_name=None,
          axis_index_groups=None):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    layer: LAYER_EVONORM_B0 or LAYER_EVONORM_S0.
    nonlinearity: use the EvoNorm nonlinearity.
    num_groups: number of groups to use for group statistics.
    group_size: size of groups, see nn.GroupNorm.
    batch_stats: a `flax.nn.Collection` used to store an exponential moving
      average of the batch statistics (default: None).
    use_running_average: if true, the statistics stored in batch_stats will be
      used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of the batch
      statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    bias: if True, bias (beta) is added.
    scale: if True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.
    bias_init: initializer for bias, by default, zero.
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).
    axis_index_groups: groups of axis indices within that named axis
      representing subsets of devices to reduce over (default: None). For
      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
      examples on the first two and last two devices. See `jax.lax.psum` for
      more details.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  x = jnp.asarray(x, jnp.float32)
  axis = axis if isinstance(axis, tuple) else (axis,)
  # pylint: disable=protected-access
  axis = nn.normalization._absolute_dims(x.ndim, axis)
  # pylint: enable=protected-access
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  instance_reduction_axis = tuple(
      i for i in range(x.ndim) if i not in axis and i > 0)
  batch_reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

  if nonlinearity:
    v = self.param('v', reduced_feature_shape,
                   jax.nn.initializers.ones).reshape(feature_shape)
    if layer == LAYER_EVONORM_S0:
      den, group_shape, input_shape = _GroupStd(
          x,
          num_groups=num_groups,
          group_size=group_size,
          epsilon=epsilon,
          dtype=dtype,
      )
      x = x * nn.sigmoid(v * x)
      x = x.reshape(group_shape)
      x /= den
      x = x.reshape(input_shape)
    elif layer == LAYER_EVONORM_B0:
      if self.is_stateful() or batch_stats:
        ra_var = self.state('var', reduced_feature_shape, initializers.ones,
                            collection=batch_stats)
      else:
        ra_var = None

      if use_running_average:
        if ra_var is None:
          raise ValueError(
              'when use_running_averages is True '
              'either use a stateful context or provide batch_stats')
        var = ra_var.value
      else:
        mean = jnp.mean(x, axis=batch_reduction_axis, keepdims=False)
        mean2 = jnp.mean(
            lax.square(x), axis=batch_reduction_axis, keepdims=False)
        if axis_name is not None and not self.is_initializing():
          concatenated_mean = jnp.concatenate([mean, mean2])
          mean, mean2 = jnp.split(
              lax.pmean(
                  concatenated_mean,
                  axis_name=axis_name,
                  axis_index_groups=axis_index_groups), 2)
        var = mean2 - lax.square(mean)

        if ra_var and not self.is_initializing():
          ra_var.value = momentum * ra_var.value + (1 - momentum) * var

      left = lax.sqrt(var + epsilon)
      instance_std = jnp.sqrt(
          x.var(axis=instance_reduction_axis, keepdims=True) + epsilon)
      right = v * x + instance_std
      x = x / jnp.maximum(left, right)
    else:
      raise ValueError('Unknown EvoNorm layer: {}'.format(layer))

  if scale:
    x *= self.param('scale', reduced_feature_shape,
                    scale_init).reshape(feature_shape)
  if bias:
    x = x + self.param('bias', reduced_feature_shape,
                       bias_init).reshape(feature_shape)
  return jnp.asarray(x, dtype)
def apply(
    self,
    x,
    num_groups=32,
    group_size=None,
    epsilon=1e-6,
    dtype=jnp.float32,
):
  """Computes grouped standard-deviation statistics (arxiv.org/abs/1803.08494).

  This op is similar to batch normalization, but statistics are shared across
  equally-sized groups of channels and not shared across batch dimension.
  Thus, group normalization does not depend on the batch composition and does
  not require maintaining internal state for storing statistics.

  The user should either specify the total number of channel groups or the
  number of channels per group.

  Args:
    x: the input of shape N...C, where N is a batch dimension and C is a
      channels dimensions. `...` represents an arbitrary number of extra
      dimensions that are used to accumulate statistics over.
    num_groups: the total number of channel groups. The default value of 32 is
      proposed by the original group normalization paper.
    group_size: the number of channels in a group.
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).

  Returns:
    A tuple of the per-group standard deviation (cast to `dtype`), the grouped
    shape, and the original input shape.
  """
  x = jnp.asarray(x, jnp.float32)
  if ((num_groups is None and group_size is None) or
      (num_groups is not None and group_size is not None)):
    raise ValueError('Either `num_groups` or `group_size` should be '
                     'specified, but not both of them.')
  channels = x.shape[-1]
  if group_size is not None:
    if channels % group_size != 0:
      raise ValueError('Number of channels ({}) is not multiple of the '
                       'group size ({}).'.format(channels, group_size))
    num_groups = channels // group_size

  while num_groups > 1:
    if channels % num_groups == 0:
      break
    num_groups -= 1

  group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)
  input_shape = x.shape
  x = x.reshape(group_shape)

  reduction_axis = list(range(1, x.ndim - 2)) + [x.ndim - 1]

  mean = jnp.mean(x, axis=reduction_axis, keepdims=True)
  mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis, keepdims=True)
  var = mean_of_squares - jnp.square(mean)
  std = lax.sqrt(var + epsilon)

  return std.astype(dtype), group_shape, input_shape
def apply(self,
          x,
          batch_stats=None,
          use_running_average=False,
          axis=-1,
          momentum=0.99,
          epsilon=1e-5,
          dtype=jnp.float32,
          axis_name=None,
          axis_index_groups=None):
  """Computes the batch-statistics denominator used to normalize the input.

  Args:
    x: the input to be normalized.
    batch_stats: a `flax.nn.Collection` used to store an exponential moving
      average of the batch statistics (default: None).
    use_running_average: if true, the statistics stored in batch_stats will be
      used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of the batch
      statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).
    axis_index_groups: groups of axis indices within that named axis
      representing subsets of devices to reduce over (default: None). For
      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
      examples on the first two and last two devices. See `jax.lax.psum` for
      more details.

  Returns:
    The per-feature normalization denominator `sqrt(var + epsilon)`, shaped
    like the feature axes of the input.
  """
  x = jnp.asarray(x, jnp.float32)
  axis = axis if isinstance(axis, tuple) else (axis,)
  # pylint: disable=protected-access
  axis = nn.normalization._absolute_dims(x.ndim, axis)
  # pylint: enable=protected-access
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
  if self.is_stateful() or batch_stats:
    ra_var = self.state('var', reduced_feature_shape, initializers.ones,
                        collection=batch_stats)
  else:
    ra_var = None

  if use_running_average:
    if ra_var is None:
      raise ValueError('when use_running_averages is True '
                       'either use a stateful context or provide batch_stats')
    var = ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if axis_name is not None and not self.is_initializing():
      concatenated_mean = jnp.concatenate([mean, mean2])
      mean, mean2 = jnp.split(
          lax.pmean(
              concatenated_mean,
              axis_name=axis_name,
              axis_index_groups=axis_index_groups), 2)
    var = mean2 - lax.square(mean)

    if ra_var and not self.is_initializing():
      ra_var.value = momentum * ra_var.value + (1 - momentum) * var

  mul = lax.sqrt(var + epsilon)
  return jnp.asarray(mul, dtype)
def __call__(self, x):
  # (1 + x/np.sqrt(4 + x**2))/2
  return lax.mul(
      0.5, lax.add(lax.div(x, lax.sqrt(lax.add(lax.square(x), 4.))), 1.))
def __call__(self, x):
  # (x + np.sqrt(x**2 + 4))/2
  return lax.mul(0.5, lax.add(x, lax.sqrt(lax.add(lax.square(x), 4.))))
def squareplus(x):
  return lax.mul(0.5, lax.add(x, lax.sqrt(lax.add(lax.square(x), 4.0))))
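# Quick check (illustrative): squareplus is a smooth, ReLU-like activation. It
# equals 1 at x = 0, tends to x for large positive x, and tends to 0 for large
# negative x; the expression (1 + x/sqrt(4 + x**2))/2 two snippets above is its
# derivative.
import jax.numpy as jnp
x = jnp.array([-10.0, 0.0, 10.0])
print(squareplus(x))  # approximately [0.099, 1.0, 10.099]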