Example #1
  def forward(self,
              x,
              epsilon=1e-6,
              bias=True,
              scale=True):
    """Applies layer normalization on the input.

    It normalizes the activations of the layer for each given example in a
    batch independently, rather than across a batch like Batch Normalization,
    i.e. it applies a transformation that maintains the mean activation within
    each example close to 0 and the activation standard deviation close to 1.

    Args:
      x: the inputs.
      epsilon: a small float added to variance to avoid dividing by zero.
      bias: if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be
        done by the next layer.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    features = x.shape[-1]
    mean = jnp.mean(x, axis=-1, keepdims=True)
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    var = mean2 - lax.square(mean)
    mul = lax.rsqrt(var + epsilon)
    if scale:
      mul = mul * self.scale
    y = (x - mean) * mul
    if bias:
      y = y + self.bias
    return y
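For reference, the normalization above is equivalent to standardizing each example over its feature axis. A minimal stand-alone sketch using only jax.numpy (the function name and eps are illustrative, not part of the example above):

import jax.numpy as jnp

def layer_norm_reference(x, eps=1e-6):
    # Normalize over the feature (last) axis, independently for each example.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

# After this transformation, each row has mean ~0 and standard deviation ~1
# (up to the epsilon term), matching Example #1 with scale and bias disabled.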
Example #2
  def __call__(self, x):
    """Applies layer normalization on the input.

    Args:
      x: the inputs

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    x = jnp.asarray(x, jnp.float32)
    features = x.shape[-1]
    mean = jnp.mean(x, axis=-1, keepdims=True)
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    var = mean2 - lax.square(mean)
    mul = lax.rsqrt(var + self.epsilon)
    if self.use_scale:
      mul = mul * jnp.asarray(
          self.param('scale', self.scale_init, (features,)),
          self.dtype)
    y = (x - mean) * mul
    if self.use_bias:
      y = y + jnp.asarray(
          self.param('bias', self.bias_init, (features,)),
          self.dtype)
    return jnp.asarray(y, self.dtype)
Example #3
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
            training: if True, compute the batch statistics from the current
                batch; otherwise use the stored running averages.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)
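The batch_stats collection above is maintained with a plain exponential moving average. A minimal sketch of just that update rule, with placeholder names not tied to the code above:

def update_running_stats(ra_mean, ra_var, batch_mean, batch_var, momentum=0.99):
    # Blend the stored statistics with the statistics of the current batch;
    # a larger momentum makes the running values change more slowly.
    new_mean = momentum * ra_mean + (1. - momentum) * batch_mean
    new_var = momentum * ra_var + (1. - momentum) * batch_var
    return new_mean, new_var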
Example #4
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay

        # exponential averaging
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1 ** t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

        # new step (both gradient and weight decay) to be applied
        update = grad_ema_corr / ( jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps ) + weight_decay * param

        # hypergradient computation and descent
        # NOTE: the original paper uses the previous update step;
        # we approximate it with the current update step.
        # This is accurate as long as we are using an averaged step,
        # especially since the exponential averaging introduces a small lag.
        hypergrad = jnp.vdot(grad, update)
        learning_rate = state.learning_rate + hypergrad * hyper_params.hypergrad_lr

        new_param = param - learning_rate * update
        new_state = _AdamHDParamState(grad_ema, grad_sq_ema, learning_rate)
        return new_param, new_state
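The learning-rate adaptation above is the hypergradient-descent (Adam-HD) idea: treat the learning rate itself as a parameter and move it along the gradient of the loss with respect to it. A minimal sketch of just that scalar update (the function name and arguments are illustrative):

import jax.numpy as jnp

def hypergradient_lr_step(lr, hypergrad_lr, grad, prev_update):
    # The parameters were moved by -lr * prev_update, so the derivative of the
    # loss with respect to lr is approximately -vdot(grad, prev_update);
    # descending that derivative adds +hypergrad_lr * vdot(grad, prev_update).
    return lr + hypergrad_lr * jnp.vdot(grad, prev_update)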
Example #5
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        lr = hyper_params.learning_rate
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        t = step + 1.
        rho_inf = 2.0 / (1 - beta2) - 1

        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
        beta2_t = beta2**t
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2_t)
        rho_t = rho_inf - 2.0 * t * beta2_t / (1 - beta2_t)
        if rho_t <= 5:
            step_size = 1.0 / (1 - beta1**t)
        else:
            step_size = lax.sqrt(
                (1 - beta2_t) * (rho_t - 4) / (rho_inf - 4) *
                (rho_t - 2) / rho_t * rho_inf / (rho_inf - 2)) / (1 - beta1**t)

        if rho_t <= 5:
            new_param = param - lr * step_size * grad_ema
            new_param -= lr * weight_decay * param
        else:
            denom = lax.sqrt(grad_sq_ema_corr) + hyper_params.eps
            new_param = param - lr * step_size * grad_ema / denom
            new_param -= lr * weight_decay * param
        new_state = _RAdamParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
Example #6
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

    centered = _where(lax_internal._isnan(a), 0,
                      a - a_mean)  # double-where trick for gradients.
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                     axis=axis,
                     keepdims=keepdims,
                     where=where)
    normalizer = normalizer - ddof
    normalizer_mask = lax.le(normalizer, 0)
    result = sum(centered, axis, keepdims=keepdims, where=where)
    result = _where(normalizer_mask, np.nan, result)
    divisor = _where(normalizer_mask, 1, normalizer)
    out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
    return lax.convert_element_type(out, dtype)
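A quick usage check for jnp.nanvar, which this helper implements: NaN entries are excluded from both the mean and the normalizer, matching NumPy's nanvar.

import numpy as np
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, jnp.nan, 4.0])
print(jnp.nanvar(a))                                  # ~1.5555556, NaN ignored
print(np.nanvar(np.array([1.0, 2.0, np.nan, 4.0])))   # same value from NumPy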
Example #7
def _log_ndtr_lower(x, series_order):
  """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
  dtype = lax.dtype(x).type
  x_2 = lax.square(x)
  # Log of the term multiplying (1 + sum)
  log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * np.log(2. * np.pi))
  return log_scale + lax.log(_log_ndtr_asymptotic_series(x, series_order))
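This lower-tail expansion is what keeps jax.scipy.special.log_ndtr finite where the CDF itself underflows. A quick illustration:

import jax.numpy as jnp
from jax.scipy.special import log_ndtr, ndtr

x = jnp.array([-20.0, -30.0])
print(jnp.log(ndtr(x)))   # -inf: ndtr(x) underflows to 0 in float32
print(log_ndtr(x))        # finite values from the asymptotic expansion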
Example #8
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        learning_rate = hyper_params.learning_rate
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        eps = hyper_params.eps
        weight_decay = hyper_params.weight_decay

        # exponential moving average for grad²
        grad_sq = lax.square(grad)
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
        # bias correction
        bias_correction2 = 1 - beta2**(step + 1.)
        grad_sq_ema_corr = grad_sq_ema / bias_correction2

        # exponential moving average for update tensor
        update = grad / (jnp.sqrt(grad_sq_ema_corr) + eps)
        update_ema = beta1 * state.update_ema + (
            1. - beta1) * learning_rate * update
        # bias correction
        bias_correction1 = beta1 * state.bias_correction1 + (
            1 - beta1) * learning_rate
        update_ema_corr = update_ema / bias_correction1

        new_param = param - learning_rate * update_ema_corr
        new_param -= learning_rate * weight_decay * param
        new_state = _LaPropParamState(update_ema, grad_sq_ema,
                                      bias_correction1)
        return new_param, new_state
Example #9
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        learning_rate = hyper_params.learning_rate

        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        t = step + 1.
        grad_ema_corr = grad_ema / (1. - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1. - beta2**t)

        update = grad_ema_corr / (jnp.sqrt(grad_sq_ema_corr) +
                                  hyper_params.eps)

        if weight_decay != 0.0:
            update += weight_decay * param

        if len(param.shape) > 1 and param.shape[0] == self._num_layers:
            norm_fn = partial(jnp.linalg.norm, keepdims=True)
            param_norm = jax.vmap(norm_fn)(param)
            update_norm = jax.vmap(norm_fn)(update)
        else:
            param_norm = jnp.linalg.norm(param)
            update_norm = jnp.linalg.norm(update)

        trust_ratio = jnp.where(
            param_norm > 0.,
            jnp.where(update_norm > 0, param_norm / update_norm, 1.0), 1.0)
        new_param = param - trust_ratio * learning_rate * update
        new_state = lamb._LAMBParamState(grad_ema, grad_sq_ema)  # pylint: disable=protected-access

        return new_param, new_state
Example #10
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        beta = hyper_params.beta
        weight_decay = hyper_params.weight_decay
        learning_rate = hyper_params.learning_rate
        eps = hyper_params.eps

        # weight decay
        if not hyper_params.use_adamWStyle_weightDecay:
            grad += weight_decay * param

        # gradient accumulation
        # power(3/2) added such that learning_rate is the actual step size rather than lr / cbrt(lr)
        weighted_lr = jnp.power(learning_rate, 3./2.) * jnp.sqrt(step + 1)
        grad_sum = state.grad_sum + weighted_lr * grad
        grad_sum_sq = state.grad_sum_sq + weighted_lr * lax.square(grad)

        # parameter update
        new_param = state.initial_param - grad_sum / (jnp.cbrt(grad_sum_sq) + eps)
        new_param = beta*param + (1. - beta)*new_param # momentum

        # AdamW-style weight decay
        if hyper_params.use_adamWStyle_weightDecay:
            new_param -= (1. - beta) * learning_rate * weight_decay * param

        new_state = _MadgradParamState(state.initial_param, grad_sum, grad_sum_sq)
        return new_param, new_state
Example #11
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        learning_rate = hyper_params.learning_rate

        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        t = step + 1.
        grad_ema_corr = grad_ema / (1. - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1. - beta2**t)

        update = grad_ema_corr / (jnp.sqrt(grad_sq_ema_corr) +
                                  hyper_params.eps)

        if weight_decay != 0.0:
            update += weight_decay * param

        param_norm = jnp.linalg.norm(param)
        update_norm = jnp.linalg.norm(update)
        trust_ratio = jnp.where(param_norm + update_norm > 0.,
                                param_norm / update_norm, 1.)

        new_param = param - trust_ratio * learning_rate * update
        new_state = _LAMBParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
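The distinctive step in this LAMB update is the trust ratio: the Adam-style update is rescaled so that its norm is proportional to the parameter norm, with a fallback of 1 when either norm is zero. A minimal sketch of that scaling in isolation (apply_trust_ratio is an illustrative name):

import jax.numpy as jnp

def apply_trust_ratio(param, update):
    # Scale the update so that its norm is proportional to the parameter norm.
    param_norm = jnp.linalg.norm(param)
    update_norm = jnp.linalg.norm(update)
    trust_ratio = jnp.where((param_norm > 0.) & (update_norm > 0.),
                            param_norm / update_norm, 1.0)
    return trust_ratio * update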
Example #12
def nadamw_update(step, hyper_params, param, state, grad):
    """Compute the next update using the nadamw optimizer.

  This term should then be *added* to the parameter value.

  Args:
    step: int
      Current training iteration.
    hyper_params: NAdamWHyperParams
      An object containing all of the hyperparameters needed to perform a step.
    param: ndarray
      Current parameter value.
    state: NAdamWParamState
      State consisting of EMAs of the gradient and the squared gradient.
    grad: ndarray
      Gradient to use when computing the update.
  Returns:
    new_param: ndarray
      The next parameter value
    new_state: NAdamWParamState
      The updated state (gradient and gradient squared) value.
  """
    assert hyper_params.learning_rate is not None, "no learning rate provided."
    beta1 = hyper_params.beta1
    beta2 = hyper_params.beta2

    lr = get_cosine_learning_rate_fn(hyper_params.training_steps,
                                     hyper_params.learning_rate,
                                     hyper_params.min_learning_rate_mult,
                                     hyper_params.constant_fraction,
                                     hyper_params.warmup_fraction)(step)

    grad = grad - param * hyper_params.l2_weight_decay

    grad_sq = lax.square(grad)

    grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad

    grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

    t = step + 1.

    # correction
    if hyper_params.use_bias_correction:
        lr_t = lr * jnp.sqrt(1.0 - beta2**t) / (1.0 - beta1**t)
    else:
        lr_t = lr

    if hyper_params.use_nesterov:
        numerator = (beta1 * grad_ema + (1.0 - beta1) * grad)
        denom = jnp.sqrt(grad_sq_ema) + hyper_params.epsilon
        step = lr_t * numerator / denom
    else:
        denom = jnp.sqrt(grad_sq_ema) + hyper_params.epsilon
        step = lr_t * grad_ema / denom

    step = step + lr_t * hyper_params.adamw_weight_decay * param

    new_state = NAdamWParamState(grad_ema, grad_sq_ema)
    return -step, new_state
Example #13
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        n_sma_threshhold = hyper_params.n_sma_threshhold
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1 ** t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

        # RAdam update
        n_sma_inf = 2. / (1 - beta2) - 1.
        n_sma_t = n_sma_inf - (2. * t * beta2 ** t) / (1. - beta2 ** t)
        # step size computation
        step_size_num = (n_sma_t - 4.) * (n_sma_t - 2.) * n_sma_inf
        step_size_denum = (n_sma_inf - 4.) * (n_sma_inf - 2.) * n_sma_t
        # abs added to deal with negative values in the first iterations (this only happens when the jnp.where below ignores step_size anyway)
        step_size = jnp.sqrt( jnp.abs(step_size_num / step_size_denum) )
        denom = jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps
        # update tensor computation
        update = jnp.where(n_sma_t > n_sma_threshhold, step_size * grad_ema_corr / denom, grad_ema_corr)

        new_param = param - hyper_params.learning_rate * update
        new_param -= hyper_params.learning_rate * weight_decay * param
        new_state = _RAdamParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
Example #14
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))
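Factoring out the larger magnitude is what keeps hypot from overflowing: squaring the inputs directly can exceed the floating-point range even when the result is representable. A quick comparison in float32:

import jax.numpy as jnp

x = jnp.float32(3e30)
y = jnp.float32(4e30)
print(jnp.sqrt(x * x + y * y))  # inf: the squares overflow float32
print(jnp.hypot(x, y))          # 5e30, computed without overflow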
Example #15
def cross_entropy_with_logits(logits, targets, z_loss=0.0):
    """Computes cross entropy loss with stable custom gradient.

  Computes a stabilized-gradient version of:
    -jnp.sum(targets * nn.log_softmax(logits), axis=-1)

  Args:
    logits: [batch * length, num_classes] float array.
    targets: categorical one-hot targets [batch * length, num_classes] float
      array.
    z_loss: coefficient for the auxiliary z-loss term.

  Returns:
    scalar cross-entropy loss
  """
    max_logit = logits.max(axis=-1, keepdims=True)
    shifted = logits - max_logit
    exp_shifted = jnp.exp(shifted)
    sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
    log_softmax = shifted - jnp.log(sum_exp)
    loss = -jnp.sum(targets * log_softmax, axis=-1)
    # Add auxiliary z-loss term.
    log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
    loss += z_loss * lax.square(log_z)
    return loss
Example #16
        def quantized_layernorm(x):
            prec = hparams.quant_hparams.prec
            fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
            quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant,
                                                     bounds=None)

            def to_quantized(x):
                return quant_ops.to_quantized(x, dtype=dtype)

            # If epsilon is too small to represent in the quantized format, we set it
            # to the minimal representative non-zero value to avoid the possibility of
            # dividing by zero.
            fp_bounds = quantization.fp_cast.get_bounds(
                prec.exp_min, prec.exp_max, prec.sig_bits)
            epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
            quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

            # If the reciprocal of the quantized number of features is too small to
            # represent in the quantized format, we set it to the minimal
            # representative nonzero value so that the mean and variance are not
            # trivially 0.
            num_features_quantized = to_quantized(
                jnp.array(num_features, dtype=dtype))
            num_features_recip_quantized = to_quantized(
                jnp.reciprocal(num_features_quantized))
            num_features_recip_quantized = jax.lax.cond(
                jax.lax.eq(num_features_recip_quantized,
                           0.0), lambda _: quantized_epsilon,
                lambda _: num_features_recip_quantized, None)

            x_quantized = to_quantized(x)
            x_sum_quantized_reduction = quantization.quantized_sum(
                x_quantized,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sum = to_quantized(x_sum_quantized_reduction)
            mean = to_quantized(x_sum * num_features_recip_quantized)
            x_minus_mean = to_quantized(x - mean)
            x_sq = to_quantized(lax.square(x_minus_mean))
            x_sq_sum_quantized_reduction = quantization.quantized_sum(
                x_sq,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
            var = to_quantized(x_sq_sum * num_features_recip_quantized)
            # Prevent division by zero.
            var_plus_epsilon = to_quantized(var + quantized_epsilon)
            mul = to_quantized(lax.rsqrt(var_plus_epsilon))
            if self.use_scale:
                quantized_scale_param = to_quantized(scale_param)
                mul = to_quantized(mul * quantized_scale_param)
            y = to_quantized(x_minus_mean * mul)
            if self.use_bias:
                quantized_bias_param = to_quantized(bias_param)
                y = to_quantized(y + quantized_bias_param)
            return y.astype(self.dtype)
Example #17
File: jet.py Project: 0x0is1/jax
def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

    c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
    tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
    tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
    for k in range(1, len(series)):
        # we know c[:k], we compute c[k]

        # propagate c to get v
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

        # propagate v to get next c

        # square
        tmp_sq[k] = fact(k) * sum(
            _scale2(k, j) * v[k - j] * v[j] for j in range(k + 1))

        # exp
        tmp_exp[k] = fact(k - 1) * sum(
            _scale(k, j) * tmp_exp[k - j] * tmp_sq[j] for j in range(1, k + 1))

        # const
        c[k] = deriv_const * tmp_exp[k]

    # we can't, and don't need to, compute c[k+1]; we just need the last v[k]
    k = len(series)
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

    primal_out, *series_out = v
    return primal_out, series_out
Example #18
def _cross_entropy_with_logits_fwd(logits, targets, z_loss=0.0):
    """Cross entropy loss forward pass."""
    max_logit = logits.max(axis=-1, keepdims=True)
    shifted = logits - max_logit
    exp_shifted = jnp.exp(shifted)
    sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
    log_softmax = shifted - jnp.log(sum_exp)
    loss = -jnp.sum(targets * log_softmax, axis=-1)
    # Add auxiliary z-loss term.
    log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
    loss += z_loss * lax.square(log_z)
    return loss, (logits, targets, z_loss, exp_shifted, sum_exp, log_softmax,
                  log_z)
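Without the z-loss term, the loss computed here is just the usual negative log-softmax cross entropy; a small sanity sketch against jax.nn.log_softmax (the array values are made up):

import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.5, -1.0]])
targets = jnp.array([[1.0, 0.0, 0.0]])  # one-hot targets
loss_reference = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
print(loss_reference)  # equals the loss above when z_loss == 0.0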
Example #19
def _atan2_taylor(primals_in, series_in):
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  x, series = jet(lax.div, primals_in, series_in)
  c0, cs = jet(lambda x: lax.div(1, 1 + lax.square(x)), (x, ), (series, ))
  c = [c0] + cs
  u = [x] + series
  v = [primal_out] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out
Example #20
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        # gets the optimizer parameters
        learning_rate = hyper_params.learning_rate
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        eps = hyper_params.eps
        weight_decay = hyper_params.weight_decay
        beta_lookahead = hyper_params.beta_lookahead
        lookahead_every_nth_iter = hyper_params.lookahead_every_nth_iter
        n_sma_threshhold = hyper_params.n_sma_threshhold

        # Applies gradient centralization
        grad = _gradient_centralization(grad, use_gc=hyper_params.use_gc)

        # computes exponential moving averages
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

        # RAdam update
        n_sma_inf = 2. / (1 - beta2) - 1.
        n_sma_t = n_sma_inf - (2. * t * beta2**t) / (1. - beta2**t)
        # step size computation
        step_size_num = (n_sma_t - 4.) * (n_sma_t - 2.) * n_sma_inf
        step_size_denum = (n_sma_inf - 4.) * (n_sma_inf - 2.) * n_sma_t
        # abs added to deal with negative values in the first iterations (this only happens when the jnp.where below ignores step_size anyway)
        step_size = jnp.sqrt(jnp.abs(step_size_num / step_size_denum))
        denom = jnp.sqrt(grad_sq_ema_corr) + eps
        # update tensor computation
        update = jnp.where(n_sma_t > n_sma_threshhold,
                           step_size * grad_ema_corr / denom, grad_ema_corr)

        # weight decay
        update += param * weight_decay

        # applies update
        new_param = param - update * learning_rate

        # integrated look ahead
        (new_param, lookahead_ema) = _lookahead(new_param, state.lookahead_ema,
                                                t, beta_lookahead,
                                                lookahead_every_nth_iter)

        new_state = _RangerParamState(grad_ema, grad_sq_ema, lookahead_ema)
        return new_param, new_state
Example #21
        def unquantized_layernorm(x):
            num_features_recip = jnp.reciprocal(num_features)
            x_sum = jnp.sum(x, axis=-1, keepdims=True)
            mean = x_sum * num_features_recip
            x_minus_mean = x - mean
            x_sq = lax.square(x_minus_mean)
            x_sq_sum = jnp.sum(x_sq, axis=-1, keepdims=True)
            var = x_sq_sum * num_features_recip
            var_plus_epsilon = var + self.epsilon
            mul = lax.rsqrt(var_plus_epsilon)
            if self.use_scale:
                mul = mul * scale_param
            y = x_minus_mean * mul
            if self.use_bias:
                y = y + bias_param
            return y.astype(self.dtype)
Example #22
def _log_ndtr_asymptotic_series(x, series_order):
    """Calculates the asymptotic series used in log_ndtr."""
    dtype = lax.dtype(x).type
    if series_order <= 0:
        return np.array(1, dtype)
    x_2 = lax.square(x)
    even_sum = jnp.zeros_like(x)
    odd_sum = jnp.zeros_like(x)
    x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
    for n in range(1, series_order + 1):
        y = np.array(_double_factorial(2 * n - 1), dtype) / x_2n
        if n % 2:
            odd_sum += y
        else:
            even_sum += y
        x_2n *= x_2
    return dtype(1.) + even_sum - odd_sum
Example #23
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = jnp.array(step + 1, lax.dtype(param.dtype))
        grad_ema_corr = grad_ema / (1 - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

        denom = jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps
        new_param = param - hyper_params.learning_rate * grad_ema_corr / denom
        new_param -= hyper_params.learning_rate * weight_decay * param
        new_state = _AdamParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
Example #24
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        smooth = hyper_params.smooth
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

        # Sadam alternative to eps
        denom = smooth_softplus(jnp.sqrt(grad_sq_ema_corr), smooth)

        new_param = param - hyper_params.learning_rate * grad_ema_corr / denom
        new_param -= hyper_params.learning_rate * weight_decay * param
        new_state = _SadamParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
Example #25
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        learning_rate = hyper_params.learning_rate
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        eps = hyper_params.eps
        weight_decay = hyper_params.weight_decay
        n_sma_threshhold = hyper_params.n_sma_threshhold
        t = step + 1.

        # RAdam step size
        n_sma_inf = 2. / (1 - beta2) - 1.
        n_sma_t = n_sma_inf - (2. * t * beta2 ** t) / (1. - beta2 ** t)
        # step size computation
        step_size_num = (n_sma_t - 4.) * (n_sma_t - 2.) * n_sma_inf
        step_size_denum = (n_sma_inf - 4.) * (n_sma_inf - 2.) * n_sma_t
        # contrary to RAdam, we keep using the denominator during the first iterations
        step_size = jnp.where(n_sma_t > n_sma_threshhold, jnp.sqrt(step_size_num / step_size_denum), 1.0)
        step_learning_rate = learning_rate * step_size

        # LaProp exponential moving average for grad²
        grad_sq = lax.square(grad)
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
        # bias correction
        bias_correction2 = 1 - beta2 ** t
        grad_sq_ema_corr = grad_sq_ema / bias_correction2

        # LaProp exponential moving average for update tensor
        update = grad / (jnp.sqrt(grad_sq_ema_corr) + eps)
        update_ema = beta1 * state.update_ema + (1. - beta1) * learning_rate * update
        # bias correction
        bias_correction1 = beta1 * state.bias_correction1 + (1 - beta1) * learning_rate
        update_ema_corr = update_ema / bias_correction1

        new_param = param - step_learning_rate * update_ema_corr
        new_param -= learning_rate * weight_decay * param
        new_state = _RLaPropParamState(update_ema, grad_sq_ema, bias_correction1)
        return new_param, new_state
Example #26
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
    return lax.convert_element_type(out, dtype)
Example #27
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1 ** t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2 ** t)

        # learning rate warmup, to deal with instability during the first iterations
        exponential_warmup = 1. - jnp.exp( - (1. - beta2) * t )
        linear_warmup = jnp.minimum(1., 0.5 * (1. - beta2) * t)
        learning_rate = hyper_params.learning_rate * jnp.where(hyper_params.use_exponential_warmup, exponential_warmup, linear_warmup)

        denom = jnp.sqrt(grad_sq_ema_corr) + hyper_params.eps
        new_param = param - learning_rate * grad_ema_corr / denom
        new_param -= learning_rate * weight_decay * param
        new_state = _RAdamSimplifiedParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
Example #28
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        delta = hyper_params.delta
        wd_ratio = hyper_params.wd_ratio
        eps = hyper_params.eps
        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq

        # bias correction
        t = step + 1.
        grad_ema_corr = grad_ema / (1 - beta1**t)
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2**t)

        # compute update
        denom = jnp.sqrt(grad_sq_ema_corr) + eps
        update = grad_ema_corr / denom

        # Projection
        if len(param.shape) > 1:
            update, wd_ratio = _project_scale_invariant(
                update, param, grad, delta, wd_ratio, eps)
        else:
            wd_ratio = 1.

        # weight decay
        new_param = param * (
            1. - hyper_params.learning_rate * weight_decay * wd_ratio)

        # update step
        new_param -= hyper_params.learning_rate * update
        new_state = _AdamPParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
Example #29
    def apply(self,
              x,
              batch_stats=None,
              use_running_average=False,
              axis=-1,
              momentum=0.99,
              epsilon=1e-5,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones,
              axis_name=None):
        """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats
        will be used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of
        the batch statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma).
        When the next layer is linear (also e.g. nn.relu), this can be disabled
        since the scaling will be done by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        x = jnp.asarray(x, jnp.float32)
        axis = axis if isinstance(axis, tuple) else (axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
        if self.is_stateful() or batch_stats:
            ra_mean = self.state('mean',
                                 reduced_feature_shape,
                                 initializers.zeros,
                                 collection=batch_stats)
            ra_var = self.state('var',
                                reduced_feature_shape,
                                initializers.ones,
                                collection=batch_stats)
        else:
            ra_mean = None
            ra_var = None

        if use_running_average:
            if ra_mean is None:
                raise ValueError('batch_stats should be provided if '
                                 'use_running_average is True')
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            if axis_name is not None and not self.is_initializing():
                mean = lax.pmean(mean, axis_name=axis_name)

            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if axis_name is not None and not self.is_initializing():
                mean2 = lax.pmean(mean2, axis_name=axis_name)
            var = mean2 - lax.square(mean)

            if ra_mean and not self.is_initializing():
                ra_mean.value = momentum * ra_mean.value + (1 -
                                                            momentum) * mean
                ra_var.value = momentum * ra_var.value + (1 - momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + epsilon)
        if scale:
            mul = mul * self.param('scale', reduced_feature_shape,
                                   scale_init).reshape(feature_shape)
        y = y * mul
        if bias:
            y = y + self.param('bias', reduced_feature_shape,
                               bias_init).reshape(feature_shape)
        return jnp.asarray(y, dtype)
Example #30
def _norm_logpdf(x):
    neg_half = _constant_like(x, -0.5)
    log_normalizer = _constant_like(x, _norm_logpdf_constant)
    return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)
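Presumably _norm_logpdf_constant is log(sqrt(2*pi)), so this is the standard normal log-density -x**2/2 - log(sqrt(2*pi)). A quick check against the public jax.scipy.stats API:

import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.array([-1.0, 0.0, 2.5])
print(norm.logpdf(x))                               # public API
print(-0.5 * x**2 - jnp.log(jnp.sqrt(2 * jnp.pi)))  # same closed form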