def forward(self, x, epsilon=1e-6, bias=True, scale=True):
  """Applies layer normalization on the input.

  It normalizes the activations of the layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  i.e. applies a transformation that maintains the mean activation within
  each example close to 0 and the activation standard deviation close to 1.

  Args:
    x: the inputs
    epsilon: A small float added to variance to avoid dividing by zero.
    bias: If True, bias (beta) is added.
    scale: If True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  mean = jnp.mean(x, axis=-1, keepdims=True)
  mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
  var = mean2 - lax.square(mean)
  mul = lax.rsqrt(var + epsilon)
  if scale:
    mul = mul * self.scale
  y = (x - mean) * mul
  if bias:
    y = y + self.bias
  return y
def __call__(self, x):
  """Applies layer normalization on the input.

  Args:
    x: the inputs

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  x = jnp.asarray(x, jnp.float32)
  features = x.shape[-1]
  mean = jnp.mean(x, axis=-1, keepdims=True)
  mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
  var = mean2 - lax.square(mean)
  mul = lax.rsqrt(var + self.epsilon)
  if self.use_scale:
    mul = mul * jnp.asarray(
        self.param('scale', self.scale_init, (features,)), self.dtype)
  y = (x - mean) * mul
  if self.use_bias:
    y = y + jnp.asarray(
        self.param('bias', self.bias_init, (features,)), self.dtype)
  return jnp.asarray(y, self.dtype)
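# A minimal usage sketch (not from the source): the __call__ above matches the
# behavior of the public flax.linen.LayerNorm, so a module exposing it is
# initialized and applied like any linen module. Shapes and epsilon below are
# illustrative assumptions.
import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.LayerNorm(epsilon=1e-6)
x = jnp.ones((2, 5, 16))
variables = layer.init(jax.random.PRNGKey(0), x)  # creates 'scale' and 'bias' params
y = layer.apply(variables, x)                     # same shape as x: (2, 5, 16)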
def __call__(self, x, training: bool):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    training: if True, compute batch statistics from `x` and update the
      running averages; otherwise use the stored running averages.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  x = jnp.asarray(x, jnp.float32)
  axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
  axis = _absolute_dims(x.ndim, axis)
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

  # we detect if we're in initialization via empty variable tree.
  initializing = not self.has_variable('batch_stats', 'mean')

  ra_mean = self.variable('batch_stats', 'mean',
                          lambda s: jnp.zeros(s, jnp.float32),
                          reduced_feature_shape)
  ra_var = self.variable('batch_stats', 'var',
                         lambda s: jnp.ones(s, jnp.float32),
                         reduced_feature_shape)

  if not training:
    mean, var = ra_mean.value, ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if self.axis_name is not None and not initializing:
      concatenated_mean = jnp.concatenate([mean, mean2])
      mean, mean2 = jnp.split(
          lax.pmean(concatenated_mean,
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
    var = mean2 - lax.square(mean)

    if not initializing:
      ra_mean.value = self.momentum * ra_mean.value + (
          1 - self.momentum) * mean
      ra_var.value = self.momentum * ra_var.value + (
          1 - self.momentum) * var

  y = x - mean.reshape(feature_shape)
  mul = lax.rsqrt(var + self.epsilon)
  if self.use_scale:
    scale = self.param('scale', self.scale_init,
                       reduced_feature_shape).reshape(feature_shape)
    mul = mul * scale
  y = y * mul
  if self.use_bias:
    bias = self.param('bias', self.bias_init,
                      reduced_feature_shape).reshape(feature_shape)
    y = y + bias
  return jnp.asarray(y, self.dtype)
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay

  # exponential averaging
  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1 ** t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

  # new step (both gradient and weight decay) to be applied
  update = grad_ema_corr / (
      jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps) + weight_decay * param

  # hypergradient computation and descent
  # NOTE: the original paper uses the previous update step; we approximate it
  # with the current update step. This is accurate as long as we are using an
  # averaged step, especially since the exponential averaging results in a
  # small lag.
  hypergrad = jnp.vdot(grad, update)
  learning_rate = state.learning_rate + hypergrad * hyper_params.hypergrad_lr

  new_param = param - learning_rate * update
  new_state = _AdamHDParamState(grad_ema, grad_sq_ema, learning_rate)
  return new_param, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  lr = hyper_params.learning_rate
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  t = step + 1.

  rho_inf = 2.0 / (1 - beta2) - 1

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  beta2_t = beta2**t
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2_t)
  rho_t = rho_inf - 2.0 * t * beta2_t / (1 - beta2_t)

  if rho_t <= 5:
    step_size = 1.0 / (1 - beta1**t)
  else:
    step_size = lax.sqrt(
        (1 - beta2_t) * (rho_t - 4) / (rho_inf - 4) * (rho_t - 2) / rho_t *
        rho_inf / (rho_inf - 2)) / (1 - beta1**t)

  if rho_t <= 5:
    new_param = param - lr * step_size * grad_ema
    new_param -= lr * weight_decay * param
  else:
    denom = lax.sqrt(grad_sq_ema_corr) + hyper_params.eps
    new_param = param - lr * step_size * grad_ema / denom
    new_param -= lr * weight_decay * param

  new_state = _RAdamParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanvar", a)
  lax_internal._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.nanvar is not supported.")

  a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

  centered = _where(lax_internal._isnan(a), 0,
                    a - a_mean)  # double-where trick for gradients.
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                   axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  normalizer_mask = lax.le(normalizer, 0)
  result = sum(centered, axis, keepdims=keepdims, where=where)
  result = _where(normalizer_mask, np.nan, result)
  divisor = _where(normalizer_mask, 1, normalizer)
  out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(out, dtype)
def _log_ndtr_lower(x, series_order):
  """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
  dtype = lax.dtype(x).type
  x_2 = lax.square(x)
  # Log of the term multiplying (1 + sum)
  log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * np.log(2. * np.pi))
  return log_scale + lax.log(_log_ndtr_asymptotic_series(x, series_order))
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  learning_rate = hyper_params.learning_rate
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  eps = hyper_params.eps
  weight_decay = hyper_params.weight_decay

  # exponential moving average for grad²
  grad_sq = lax.square(grad)
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
  # bias correction
  bias_correction2 = 1 - beta2**(step + 1.)
  grad_sq_ema_corr = grad_sq_ema / bias_correction2

  # exponential moving average for update tensor
  update = grad / (jnp.sqrt(grad_sq_ema_corr) + eps)
  update_ema = beta1 * state.update_ema + (
      1. - beta1) * learning_rate * update
  # bias correction
  bias_correction1 = beta1 * state.bias_correction1 + (
      1 - beta1) * learning_rate
  update_ema_corr = update_ema / bias_correction1

  new_param = param - learning_rate * update_ema_corr
  new_param -= learning_rate * weight_decay * param
  new_state = _LaPropParamState(update_ema, grad_sq_ema, bias_correction1)
  return new_param, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  learning_rate = hyper_params.learning_rate

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  t = step + 1.
  grad_ema_corr = grad_ema / (1. - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1. - beta2**t)

  update = grad_ema_corr / (jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps)

  if weight_decay != 0.0:
    update += weight_decay * param

  if len(param.shape) > 1 and param.shape[0] == self._num_layers:
    norm_fn = partial(jnp.linalg.norm, keepdims=True)
    param_norm = jax.vmap(norm_fn)(param)
    update_norm = jax.vmap(norm_fn)(update)
  else:
    param_norm = jnp.linalg.norm(param)
    update_norm = jnp.linalg.norm(update)

  trust_ratio = jnp.where(
      param_norm > 0.,
      jnp.where(update_norm > 0, param_norm / update_norm, 1.0), 1.0)
  new_param = param - trust_ratio * learning_rate * update
  new_state = lamb._LAMBParamState(grad_ema, grad_sq_ema)  # pylint: disable=protected-access
  return new_param, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  beta = hyper_params.beta
  weight_decay = hyper_params.weight_decay
  learning_rate = hyper_params.learning_rate
  eps = hyper_params.eps

  # weight decay
  if not hyper_params.use_adamWStyle_weightDecay:
    grad += weight_decay * param

  # gradient accumulation
  # power(3/2) added such that learning_rate is the actual step size
  # rather than lr / cbrt(lr)
  weighted_lr = jnp.power(learning_rate, 3. / 2.) * jnp.sqrt(step + 1)
  grad_sum = state.grad_sum + weighted_lr * grad
  grad_sum_sq = state.grad_sum_sq + weighted_lr * lax.square(grad)

  # parameter update
  new_param = state.initial_param - grad_sum / (jnp.cbrt(grad_sum_sq) + eps)
  new_param = beta * param + (1. - beta) * new_param  # momentum

  # AdamW-style weight decay
  if hyper_params.use_adamWStyle_weightDecay:
    new_param -= (1. - beta) * learning_rate * weight_decay * param

  new_state = _MadgradParamState(state.initial_param, grad_sum, grad_sum_sq)
  return new_param, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  learning_rate = hyper_params.learning_rate

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  t = step + 1.
  grad_ema_corr = grad_ema / (1. - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1. - beta2**t)

  update = grad_ema_corr / (jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps)

  if weight_decay != 0.0:
    update += weight_decay * param

  param_norm = jnp.linalg.norm(param)
  update_norm = jnp.linalg.norm(update)
  trust_ratio = jnp.where(param_norm + update_norm > 0.,
                          param_norm / update_norm, 1.)

  new_param = param - trust_ratio * learning_rate * update
  new_state = _LAMBParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
def nadamw_update(step, hyper_params, param, state, grad):
  """Compute the next update using the nadamw optimizer.

  This term should then be *added* to the parameter value.

  Args:
    step: int
      Current training iteration.
    hyper_params: NAdamWHyperParams
      An object containing all of the hyper parameters to perform a step.
    param: ndarray
      Current parameter value.
    state: NAdamWParamState
      State consisting of EMA of the gradient and gradient squared.
    grad: ndarray
      Gradient to use when computing the update.

  Returns:
    new_param: ndarray
      The next parameter value
    new_state: NAdamWParamState
      The updated state (gradient and gradient squared) value.
  """
  assert hyper_params.learning_rate is not None, "no learning rate provided."
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  lr = get_cosine_learning_rate_fn(hyper_params.training_steps,
                                   hyper_params.learning_rate,
                                   hyper_params.min_learning_rate_mult,
                                   hyper_params.constant_fraction,
                                   hyper_params.warmup_fraction)(step)

  grad = grad - param * hyper_params.l2_weight_decay

  grad_sq = lax.square(grad)

  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  t = step + 1.

  # correction
  if hyper_params.use_bias_correction:
    lr_t = lr * jnp.sqrt(1.0 - beta2**t) / (1.0 - beta1**t)
  else:
    lr_t = lr

  if hyper_params.use_nesterov:
    numerator = (beta1 * grad_ema + (1.0 - beta1) * grad)
    denom = jnp.sqrt(grad_sq_ema) + hyper_params.epsilon
    step = lr_t * numerator / denom
  else:
    denom = jnp.sqrt(grad_sq_ema) + hyper_params.epsilon
    step = lr_t * grad_ema / denom

  step = step + lr_t * hyper_params.adamw_weight_decay * param

  new_state = NAdamWParamState(grad_ema, grad_sq_ema)
  return -step, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  n_sma_threshhold = hyper_params.n_sma_threshhold

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1 ** t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

  # RAdam update
  n_sma_inf = 2. / (1 - beta2) - 1.
  n_sma_t = n_sma_inf - (2. * t * beta2 ** t) / (1. - beta2 ** t)

  # step size computation
  step_size_num = (n_sma_t - 4.) * (n_sma_t - 2.) * n_sma_inf
  step_size_denum = (n_sma_inf - 4.) * (n_sma_inf - 2.) * n_sma_t
  # abs deals with negative values during the first iterations, when the
  # n_sma_t > n_sma_threshhold test discards step_size anyway
  step_size = jnp.sqrt(jnp.abs(step_size_num / step_size_denum))
  denom = jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps

  # update tensor computation
  update = jnp.where(n_sma_t > n_sma_threshhold,
                     step_size * grad_ema_corr / denom, grad_ema_corr)

  new_param = param - hyper_params.learning_rate * update
  new_param -= hyper_params.learning_rate * weight_decay * param
  new_state = _RAdamParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
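# A small standalone sketch (plain numpy, not from the source) showing that the
# step-size factor above is the RAdam rectification term
#   r_t = sqrt(((rho_t - 4)(rho_t - 2) rho_inf) / ((rho_inf - 4)(rho_inf - 2) rho_t)),
# with rho_inf = 2 / (1 - beta2) - 1 and rho_t = rho_inf - 2 t beta2**t / (1 - beta2**t).
# Values of t and beta2 below are illustrative assumptions.
import numpy as np

def radam_rectifier(t, beta2=0.999):
  rho_inf = 2. / (1 - beta2) - 1.
  rho_t = rho_inf - 2. * t * beta2**t / (1. - beta2**t)
  num = (rho_t - 4.) * (rho_t - 2.) * rho_inf
  den = (rho_inf - 4.) * (rho_inf - 2.) * rho_t
  return rho_t, np.sqrt(np.abs(num / den))

for t in (1., 5., 100., 10000.):
  print(t, radam_rectifier(t))  # the rectifier approaches 1 as t grows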
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  return lax.select(
      x1 == 0, x1,
      x1 * lax.sqrt(1 + lax.square(lax.div(
          x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))
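# Quick sanity check (an assumption, not from the source): this looks like the
# internal definition behind jnp.hypot, so the public function is used here.
# Rescaling by the larger magnitude keeps the intermediate square from
# overflowing for large inputs.
import jax.numpy as jnp

print(jnp.hypot(3.0, 4.0))    # 5.0
print(jnp.hypot(1e38, 1e38))  # ~1.414e38 in float32, instead of overflowing to inf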
def cross_entropy_with_logits(logits, targets, z_loss=0.0):
  """Computes cross entropy loss with stable custom gradient.

  Computes a stabilized-gradient version of:
    -jnp.sum(targets * nn.log_softmax(logits), axis=-1)

  Args:
    logits: [batch * length, num_classes] float array.
    targets: categorical one-hot targets [batch * length, num_classes] float
      array.
    z_loss: coefficient for auxiliary z-loss loss term.

  Returns:
    Per-example cross-entropy loss, a [batch * length] float array.
  """
  max_logit = logits.max(axis=-1, keepdims=True)
  shifted = logits - max_logit
  exp_shifted = jnp.exp(shifted)
  sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
  log_softmax = shifted - jnp.log(sum_exp)
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add auxiliary z-loss term.
  log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
  loss += z_loss * lax.square(log_z)
  return loss
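# A hedged smoke test (illustrative values, assumes cross_entropy_with_logits
# above is in scope): 4 examples, 3 classes, one-hot targets. The z_loss term
# only adds a small penalty on the log-partition function.
import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
targets = jax.nn.one_hot(jnp.array([0, 2, 1, 0]), 3)
loss = cross_entropy_with_logits(logits, targets, z_loss=1e-4)
print(loss.shape)  # (4,) -- one loss value per example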
def quantized_layernorm(x):
  prec = hparams.quant_hparams.prec
  fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
  quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

  def to_quantized(x):
    return quant_ops.to_quantized(x, dtype=dtype)

  # If epsilon is too small to represent in the quantized format, we set it
  # to the minimal representable non-zero value to avoid the possibility of
  # dividing by zero.
  fp_bounds = quantization.fp_cast.get_bounds(prec.exp_min, prec.exp_max,
                                              prec.sig_bits)
  epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
  quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

  # If the reciprocal of the quantized number of features is too small to
  # represent in the quantized format, we set it to the minimal
  # representable nonzero value so that the mean and variance are not
  # trivially 0.
  num_features_quantized = to_quantized(jnp.array(num_features, dtype=dtype))
  num_features_recip_quantized = to_quantized(
      jnp.reciprocal(num_features_quantized))
  num_features_recip_quantized = jax.lax.cond(
      jax.lax.eq(num_features_recip_quantized, 0.0),
      lambda _: quantized_epsilon,
      lambda _: num_features_recip_quantized, None)

  x_quantized = to_quantized(x)
  x_sum_quantized_reduction = quantization.quantized_sum(
      x_quantized, axis=-1, keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sum = to_quantized(x_sum_quantized_reduction)
  mean = to_quantized(x_sum * num_features_recip_quantized)
  x_minus_mean = to_quantized(x - mean)
  x_sq = to_quantized(lax.square(x_minus_mean))
  x_sq_sum_quantized_reduction = quantization.quantized_sum(
      x_sq, axis=-1, keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
  var = to_quantized(x_sq_sum * num_features_recip_quantized)
  # Prevent division by zero.
  var_plus_epsilon = to_quantized(var + quantized_epsilon)
  mul = to_quantized(lax.rsqrt(var_plus_epsilon))
  if self.use_scale:
    quantized_scale_param = to_quantized(scale_param)
    mul = to_quantized(mul * quantized_scale_param)
  y = to_quantized(x_minus_mean * mul)
  if self.use_bias:
    quantized_bias_param = to_quantized(bias_param)
    y = to_quantized(y + quantized_bias_param)
  return y.astype(self.dtype)
def _erf_inv_rule(primals_in, series_in):
  x, = primals_in
  series, = series_in

  u = [x] + series
  primal_out = lax.erf_inv(x)
  v = [primal_out] + [None] * len(series)

  # derivative on co-domain for caching purposes
  deriv_const = np.sqrt(np.pi) / 2.
  deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

  # manually propagate through deriv_y since we don't have lazy evaluation of
  # sensitivities
  c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
  tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
  tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
  for k in range(1, len(series)):
    # we know c[:k], we compute c[k]

    # propagate c to get v
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

    # propagate v to get next c

    # square
    tmp_sq[k] = fact(k) * sum(
        _scale2(k, j) * v[k - j] * v[j] for j in range(k + 1))

    # exp
    tmp_exp[k] = fact(k - 1) * sum(
        _scale(k, j) * tmp_exp[k - j] * tmp_sq[j] for j in range(1, k + 1))

    # const
    c[k] = deriv_const * tmp_exp[k]

  # we can't, and don't need to, compute c[k+1]; we just need the last v[k]
  k = len(series)
  v[k] = fact(k - 1) * sum(
      _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

  primal_out, *series_out = v
  return primal_out, series_out
def _cross_entropy_with_logits_fwd(logits, targets, z_loss=0.0):
  """Cross entropy loss forward pass."""
  max_logit = logits.max(axis=-1, keepdims=True)
  shifted = logits - max_logit
  exp_shifted = jnp.exp(shifted)
  sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
  log_softmax = shifted - jnp.log(sum_exp)
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add auxiliary z-loss term.
  log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
  loss += z_loss * lax.square(log_z)
  return loss, (logits, targets, z_loss, exp_shifted, sum_exp, log_softmax,
                log_z)
def _atan2_taylor(primals_in, series_in):
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  x, series = jet(lax.div, primals_in, series_in)
  c0, cs = jet(lambda x: lax.div(1, 1 + lax.square(x)), (x, ), (series, ))
  c = [c0] + cs
  u = [x] + series
  v = [primal_out] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  # gets the optimizer parameters
  learning_rate = hyper_params.learning_rate
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  eps = hyper_params.eps
  weight_decay = hyper_params.weight_decay
  beta_lookahead = hyper_params.beta_lookahead
  lookahead_every_nth_iter = hyper_params.lookahead_every_nth_iter
  n_sma_threshhold = hyper_params.n_sma_threshhold

  # applies gradient centralization
  grad = _gradient_centralization(grad, use_gc=hyper_params.use_gc)

  # computes exponential moving averages
  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

  # RAdam update
  n_sma_inf = 2. / (1 - beta2) - 1.
  n_sma_t = n_sma_inf - (2. * t * beta2**t) / (1. - beta2**t)

  # step size computation
  step_size_num = (n_sma_t - 4.) * (n_sma_t - 2.) * n_sma_inf
  step_size_denum = (n_sma_inf - 4.) * (n_sma_inf - 2.) * n_sma_t
  # abs deals with negative values during the first iterations, when the
  # n_sma_t > n_sma_threshhold test discards step_size anyway
  step_size = jnp.sqrt(jnp.abs(step_size_num / step_size_denum))
  denom = jnp.sqrt(grad_sq_ema_corr) + eps

  # update tensor computation
  update = jnp.where(n_sma_t > n_sma_threshhold,
                     step_size * grad_ema_corr / denom, grad_ema_corr)

  # weight decay
  update += param * weight_decay

  # applies update
  new_param = param - update * learning_rate

  # integrated lookahead
  (new_param, lookahead_ema) = _lookahead(new_param, state.lookahead_ema, t,
                                          beta_lookahead,
                                          lookahead_every_nth_iter)

  new_state = _RangerParamState(grad_ema, grad_sq_ema, lookahead_ema)
  return new_param, new_state
def unquantized_layernorm(x):
  num_features_recip = jnp.reciprocal(num_features)
  x_sum = jnp.sum(x, axis=-1, keepdims=True)
  mean = x_sum * num_features_recip
  x_minus_mean = x - mean
  x_sq = lax.square(x_minus_mean)
  x_sq_sum = jnp.sum(x_sq, axis=-1, keepdims=True)
  var = x_sq_sum * num_features_recip
  var_plus_epsilon = var + self.epsilon
  mul = lax.rsqrt(var_plus_epsilon)
  if self.use_scale:
    mul = mul * scale_param
  y = x_minus_mean * mul
  if self.use_bias:
    y = y + bias_param
  return y.astype(self.dtype)
def _log_ndtr_asymptotic_series(x, series_order):
  """Calculates the asymptotic series used in log_ndtr."""
  dtype = lax.dtype(x).type
  if series_order <= 0:
    return np.array(1, dtype)
  x_2 = lax.square(x)
  even_sum = jnp.zeros_like(x)
  odd_sum = jnp.zeros_like(x)
  x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
  for n in range(1, series_order + 1):
    y = np.array(_double_factorial(2 * n - 1), dtype) / x_2n
    if n % 2:
      odd_sum += y
    else:
      even_sum += y
    x_2n *= x_2
  return dtype(1.) + even_sum - odd_sum
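# A quick numerical check (a sketch, not from the source) of the lower-tail
# approximation: combined with _log_ndtr_lower above, the series should agree
# with the public jax.scipy.special.log_ndtr for x << -1. Assumes
# _log_ndtr_lower, _log_ndtr_asymptotic_series, and their np/lax imports are
# in scope; x and series_order below are illustrative.
import jax.numpy as jnp
from jax.scipy import special as jsp_special

x = jnp.float32(-8.0)
approx = _log_ndtr_lower(x, series_order=3)
exact = jsp_special.log_ndtr(x)
print(approx, exact)  # both approximately -35.01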
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = jnp.array(step + 1, lax.dtype(param.dtype))
  grad_ema_corr = grad_ema / (1 - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

  denom = jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps
  new_param = param - hyper_params.learning_rate * grad_ema_corr / denom
  new_param -= hyper_params.learning_rate * weight_decay * param
  new_state = _AdamParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
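# A worked single step of the bias-corrected update above, recomputed with
# plain numpy to make the arithmetic concrete (hypothetical values, not from
# the source). The state starts at zero, so at t = 1 the corrected EMAs equal
# the raw gradient and squared gradient.
import numpy as np

beta1, beta2, eps, lr, wd = 0.9, 0.999, 1e-8, 1e-3, 0.01
param, grad = np.float32(0.5), np.float32(0.2)
grad_ema = (1 - beta1) * grad
grad_sq_ema = (1 - beta2) * grad**2
grad_ema_corr = grad_ema / (1 - beta1**1)        # == grad
grad_sq_ema_corr = grad_sq_ema / (1 - beta2**1)  # == grad**2
new_param = param - lr * grad_ema_corr / (np.sqrt(grad_sq_ema_corr) + eps)
new_param -= lr * wd * param
print(new_param)  # ~0.5 - 1e-3 - 5e-6 = 0.498995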
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  smooth = hyper_params.smooth

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

  # Sadam alternative to eps
  denom = smooth_softplus(jnp.sqrt(grad_sq_ema_corr), smooth)

  new_param = param - hyper_params.learning_rate * grad_ema_corr / denom
  new_param -= hyper_params.learning_rate * weight_decay * param
  new_state = _SadamParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  learning_rate = hyper_params.learning_rate
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  eps = hyper_params.eps
  weight_decay = hyper_params.weight_decay
  n_sma_threshhold = hyper_params.n_sma_threshhold
  t = step + 1.

  # RAdam step size
  n_sma_inf = 2. / (1 - beta2) - 1.
  n_sma_t = n_sma_inf - (2. * t * beta2 ** t) / (1. - beta2 ** t)
  # step size computation
  step_size_num = (n_sma_t - 4.) * (n_sma_t - 2.) * n_sma_inf
  step_size_denum = (n_sma_inf - 4.) * (n_sma_inf - 2.) * n_sma_t
  # we do use the denominator for the first iterations, contrary to RAdam
  step_size = jnp.where(n_sma_t > n_sma_threshhold,
                        jnp.sqrt(step_size_num / step_size_denum), 1.0)
  step_learning_rate = learning_rate * step_size

  # LaProp exponential moving average for grad²
  grad_sq = lax.square(grad)
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
  # bias correction
  bias_correction2 = 1 - beta2 ** t
  grad_sq_ema_corr = grad_sq_ema / bias_correction2

  # LaProp exponential moving average for update tensor
  update = grad / (jnp.sqrt(grad_sq_ema_corr) + eps)
  update_ema = beta1 * state.update_ema + (
      1. - beta1) * learning_rate * update
  # bias correction
  bias_correction1 = beta1 * state.bias_correction1 + (
      1 - beta1) * learning_rate
  update_ema_corr = update_ema / bias_correction1

  new_param = param - step_learning_rate * update_ema_corr
  new_param -= learning_rate * weight_decay * param
  new_state = _RLaPropParamState(update_ema, grad_sq_ema, bias_correction1)
  return new_param, new_state
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
  _check_arraylike("var", a)
  lax_internal._check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.var is not supported.")

  computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a = a.astype(computation_dtype)
  a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
  centered = lax.sub(a, a_mean)
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
                     dtype=dtype, keepdims=keepdims)
  normalizer = normalizer - ddof

  result = sum(centered, axis, keepdims=keepdims, where=where)
  out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
  return lax.convert_element_type(out, dtype)
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1 ** t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

  # learning rate warmup, to deal with instability during the first iterations
  exponential_warmup = 1. - jnp.exp(-(1. - beta2) * t)
  linear_warmup = jnp.minimum(1., 0.5 * (1. - beta2) * t)
  learning_rate = hyper_params.learning_rate * jnp.where(
      hyper_params.use_exponential_warmup, exponential_warmup, linear_warmup)

  denom = jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps
  new_param = param - learning_rate * grad_ema_corr / denom
  new_param -= learning_rate * weight_decay * param
  new_state = _RAdamSimplifiedParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  beta1 = hyper_params.beta1
  beta2 = hyper_params.beta2
  weight_decay = hyper_params.weight_decay
  delta = hyper_params.delta
  wd_ratio = hyper_params.wd_ratio
  eps = hyper_params.eps

  grad_sq = lax.square(grad)
  grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
  grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

  # bias correction
  t = step + 1.
  grad_ema_corr = grad_ema / (1 - beta1**t)
  grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

  # compute update
  denom = jnp.sqrt(grad_sq_ema_corr) + eps
  update = grad_ema_corr / denom

  # Projection
  if len(param.shape) > 1:
    update, wd_ratio = _project_scale_invariant(update, param, grad, delta,
                                                wd_ratio, eps)
  else:
    wd_ratio = 1.

  # weight decay
  new_param = param * (
      1. - hyper_params.learning_rate * weight_decay * wd_ratio)

  # update step
  new_param -= hyper_params.learning_rate * update

  new_state = _AdamPParamState(grad_ema, grad_sq_ema)
  return new_param, new_state
def apply(self,
          x,
          batch_stats=None,
          use_running_average=False,
          axis=-1,
          momentum=0.99,
          epsilon=1e-5,
          dtype=jnp.float32,
          bias=True,
          scale=True,
          bias_init=initializers.zeros,
          scale_init=initializers.ones,
          axis_name=None):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    batch_stats: a `flax.nn.Collection` used to store an exponential moving
      average of the batch statistics (default: None).
    use_running_average: if true, the statistics stored in batch_stats will be
      used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of the batch
      statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    bias: if True, bias (beta) is added.
    scale: if True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.
    bias_init: initializer for bias, by default, zero.
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  x = jnp.asarray(x, jnp.float32)
  axis = axis if isinstance(axis, tuple) else (axis,)
  axis = _absolute_dims(x.ndim, axis)
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
  if self.is_stateful() or batch_stats:
    ra_mean = self.state('mean', reduced_feature_shape,
                         initializers.zeros, collection=batch_stats)
    ra_var = self.state('var', reduced_feature_shape,
                        initializers.ones, collection=batch_stats)
  else:
    ra_mean = None
    ra_var = None

  if use_running_average:
    if ra_mean is None:
      raise ValueError('batch_stats should be provided if '
                       'use_running_averages is True')
    mean, var = ra_mean.value, ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    if axis_name is not None and not self.is_initializing():
      mean = lax.pmean(mean, axis_name=axis_name)

    mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
    if axis_name is not None and not self.is_initializing():
      mean2 = lax.pmean(mean2, axis_name=axis_name)
    var = mean2 - lax.square(mean)

    if ra_mean and not self.is_initializing():
      ra_mean.value = momentum * ra_mean.value + (1 - momentum) * mean
      ra_var.value = momentum * ra_var.value + (1 - momentum) * var

  y = x - mean.reshape(feature_shape)
  mul = lax.rsqrt(var + epsilon)
  if scale:
    mul = mul * self.param('scale', reduced_feature_shape,
                           scale_init).reshape(feature_shape)
  y = y * mul
  if bias:
    y = y + self.param('bias', reduced_feature_shape,
                       bias_init).reshape(feature_shape)
  return jnp.asarray(y, dtype)
def _norm_logpdf(x):
  neg_half = _constant_like(x, -0.5)
  log_normalizer = _constant_like(x, _norm_logpdf_constant)
  return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)