Example #1
 def _stats_to_bounds(self, stats_value):
     """Computes activation clipping bounds from activation statistics."""
     hyper = self.hyper
     maximum = stats_value.mean_batch_maximum
     minimum = stats_value.mean_batch_minimum
     mom = jnp.maximum(jnp.abs(maximum), jnp.abs(minimum))
     stddev_uncentered = lax.sqrt(stats_value.mean_sq)
     absdev_uncentered = stats_value.mean_abs
     stddev = lax.sqrt(stats_value.mean_sq - stats_value.mean**2)
     abs_mean = jnp.abs(stats_value.mean)
     if hyper.use_old_code:  # old code of computing the bound
         if hyper.use_cams:  # upper confidence bound formula
             return abs_mean + hyper.stddev_coeff * stddev
         elif hyper.use_mean_of_max:
             return mom
         else:
             return (
                 hyper.mix_coeff * hyper.stddev_coeff * stddev_uncentered +
                 (1 - hyper.mix_coeff) * hyper.absdev_coeff *
                 absdev_uncentered)
     else:  # use new way of computing the bound
         cams = abs_mean + hyper.cams_stddev_coeff * stddev
         return (hyper.fixed_bound + hyper.mean_of_max_coeff * mom +
                 hyper.stddev_coeff * stddev_uncentered +
                 hyper.absdev_coeff * absdev_uncentered +
                 hyper.cams_coeff * cams)
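A minimal sketch of the "new" bound formula above, with made-up statistics and coefficients (every name and value here is illustrative, not the real hyperparameters or stats_value fields):

import jax.numpy as jnp
from jax import lax

# Illustrative activation statistics (stand-ins for stats_value fields).
mean, mean_sq, mean_abs = 0.1, 1.2, 0.8
mean_batch_maximum, mean_batch_minimum = 3.0, -2.5

mom = jnp.maximum(jnp.abs(mean_batch_maximum), jnp.abs(mean_batch_minimum))
stddev = lax.sqrt(mean_sq - mean ** 2)          # centered stddev
cams = jnp.abs(mean) + 3.0 * stddev             # cams_stddev_coeff = 3.0 (made up)
bound = (1.0                                    # fixed_bound
         + 0.5 * mom                            # mean_of_max_coeff * mean of max
         + 2.0 * lax.sqrt(mean_sq)              # stddev_coeff * uncentered stddev
         + 1.0 * mean_abs                       # absdev_coeff * uncentered absdev
         + 0.25 * cams)                         # cams_coeff * cams
print(bound)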
Example #2
File: radam.py Project: rwbfd/flax
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        lr = hyper_params.learning_rate
        beta1 = hyper_params.beta1
        beta2 = hyper_params.beta2
        weight_decay = hyper_params.weight_decay
        t = step + 1.
        rho_inf = 2.0 / (1 - beta2) - 1

        grad_sq = lax.square(grad)
        grad_ema = beta1 * state.grad_ema + (1. - beta1) * grad
        grad_sq_ema = beta2 * state.grad_sq_ema + (1. - beta2) * grad_sq
        beta2_t = beta2**t
        grad_sq_ema_corr = grad_sq_ema / (1 - beta2_t)
        rho_t = rho_inf - 2.0 * t * beta2_t / (1 - beta2_t)
        if rho_t <= 5:
            step_size = 1.0 / (1 - beta1**t)
        else:
            step_size = lax.sqrt(
                (1 - beta2_t) * (rho_t - 4) / (rho_inf - 4) *
                (rho_t - 2) / rho_t * rho_inf / (rho_inf - 2)) / (1 - beta1**t)

        if rho_t <= 5:
            new_param = param - lr * step_size * grad_ema
            new_param -= lr * weight_decay * param
        else:
            denom = lax.sqrt(grad_sq_ema_corr) + hyper_params.eps
            new_param = param - lr * step_size * grad_ema / denom
            new_param -= lr * weight_decay * param
        new_state = _RAdamParamState(grad_ema, grad_sq_ema)
        return new_param, new_state
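As a standalone illustration (not the Flax optimizer itself), the rectification term rho_t below gates between the plain bias-corrected step size and the variance-rectified one; the beta and step values are arbitrary:

import jax.numpy as jnp

beta1, beta2, t = 0.9, 0.999, 10.0              # illustrative values
rho_inf = 2.0 / (1 - beta2) - 1
beta2_t = beta2 ** t
rho_t = rho_inf - 2.0 * t * beta2_t / (1 - beta2_t)

if rho_t <= 5:                                  # variance estimate not yet reliable
    step_size = 1.0 / (1 - beta1 ** t)
else:                                           # variance-rectified step size
    step_size = jnp.sqrt(
        (1 - beta2_t) * (rho_t - 4) / (rho_inf - 4)
        * (rho_t - 2) / rho_t * rho_inf / (rho_inf - 2)) / (1 - beta1 ** t)
print(rho_t, step_size)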
Example #3
 def _stats_to_bounds(self, stats_value):
   """Computes activation clipping bounds from activation statistics."""
   hyper = self.hyper
   if hyper.use_cams:  # upper confidence bound formula
     return jnp.abs(stats_value.mean) + hyper.stddev_coeff * lax.sqrt(
         stats_value.mean_sq - stats_value.mean**2)
   elif hyper.use_mean_of_max:
     maximum = stats_value.mean_batch_maximum
     minimum = stats_value.mean_batch_minimum
     return jnp.maximum(jnp.abs(maximum), jnp.abs(minimum))
   else:
     stddev_uncentered = lax.sqrt(stats_value.mean_sq)
     absdev_uncentered = stats_value.mean_abs
     return (hyper.mix_coeff * hyper.stddev_coeff * stddev_uncentered +
             (1 - hyper.mix_coeff) * hyper.absdev_coeff * absdev_uncentered)
Example #4
def _dct_ortho_norm(out, axis):
    factor = lax.concatenate([
        lax.full((1, ), 4, out.dtype),
        lax.full((out.shape[axis] - 1, ), 2, out.dtype)
    ], 0)
    factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
    return out / lax.sqrt(factor * out.shape[axis])
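Hypothetical usage of the helper above (assuming _dct_ortho_norm and the lax/jnp imports are in scope): the k = 0 coefficient is divided by sqrt(4*N) and all others by sqrt(2*N), matching an 'ortho'-normalized DCT-II.

import jax.numpy as jnp
from jax import lax

out = jnp.ones((3, 8))                  # pretend these are raw DCT coefficients, N = 8
normed = _dct_ortho_norm(out, axis=1)
print(normed[0, 0], normed[0, 1])       # 1/sqrt(32) ~= 0.177 and 1/sqrt(16) = 0.25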
Example #5
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax_internal._ones(x1), x1)))))
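The public entry point for this implementation is jnp.hypot; factoring out the larger magnitude before squaring the ratio keeps intermediate values from overflowing:

import jax.numpy as jnp

print(jnp.hypot(3.0, 4.0))              # 5.0
print(jnp.hypot(1e30, 1e30))            # ~1.414e30 in float32; naive sqrt(x**2 + y**2) overflows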
Example #6
def _sqrt2(x):
    x, xx = x
    c = lax.sqrt(x)
    u, uu = _mul12(c, c)
    cc = (x - u - uu + xx) * 0.5 / c
    y = c + cc
    yy = c - y + cc
    return y, yy
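A rough illustration of the error-compensation idea behind _sqrt2 (this is not the _mul12-based code above; a float64 product stands in for the exact two-product): the pair (y, yy) carries the rounded square root plus a correction for its rounding error.

import numpy as np

x = np.float32(2.0)
c = np.float32(np.sqrt(x))                            # rounded square root
residual = np.float64(x) - np.float64(c) * np.float64(c)
cc = np.float32(0.5 * residual / np.float64(c))       # Newton-style correction term
y, yy = c, cc
print(float(y) + float(yy) - np.sqrt(np.float64(x)))  # residual error far below float32 ulp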
Example #7
def batchnorm(x, s, bias, mean, var, epsilon=1e-5):
    dims_x = len(x.shape)
    dim_ones = (1,) * (dims_x - 2)
    s = s.reshape(-1, *dim_ones)
    bias = bias.reshape(-1, *dim_ones)
    mean = mean.reshape(-1, *dim_ones)
    var = var.reshape(-1, *dim_ones)
    ot = s * (x - mean) / lax.sqrt(var + epsilon) + bias
    return ot
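A hypothetical call to the batchnorm helper above on NCHW input (assumes the function and its lax import are in scope); the per-channel parameters are reshaped to (C, 1, 1) inside the function so they broadcast over H and W:

import jax
import jax.numpy as jnp
from jax import lax

x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 4))    # N, C, H, W
s, bias = jnp.ones(3), jnp.zeros(3)
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
y = batchnorm(x, s, bias, mean, var)
print(y.mean(), y.std())                                      # ~0 and ~1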
Example #8
def nanstd(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    _check_arraylike("nanstd", a)
    lax_internal._check_user_dtype_supported(dtype, "nanstd")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanstd is not supported.")
    return lax.sqrt(
        nanvar(a,
               axis=axis,
               dtype=dtype,
               ddof=ddof,
               keepdims=keepdims,
               where=where))
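The public wrapper is jnp.nanstd: NaNs are ignored in the variance before the final lax.sqrt, so only the valid entries contribute:

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, jnp.nan])
print(jnp.nanstd(a))                    # std of [1, 2, 3] ~= 0.8165
print(jnp.std(a))                       # nan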
Example #9
def _ndtri(p):
    """Implements ndtri core logic."""

    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(maybe_complement_p <= dtype(0.),
                              jnp.full(shape, dtype(0.5)), maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                                _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (_create_polynomial(dtype(1.) / z, p2) /
                           _create_polynomial(dtype(1.) / z, q2) / z)
    second_term_otherwise = (_create_polynomial(dtype(1.) / z, p1) /
                             _create_polynomial(dtype(1.) / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p,
                  jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

    x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
    infinity = jnp.full(shape, dtype(np.inf))
    x_nan_replaced = jnp.where(p <= dtype(0.0), -infinity,
                               jnp.where(p >= dtype(1.0), infinity, x))
    return x_nan_replaced
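_ndtri backs jax.scipy.special.ndtri, the inverse of the standard normal CDF; out-of-range probabilities map to +/-inf as in the final jnp.where above:

import jax.numpy as jnp
from jax.scipy import special

print(special.ndtri(jnp.array([0.025, 0.5, 0.975])))   # ~[-1.96, 0.0, 1.96]
print(special.ndtri(jnp.array([0.0, 1.0])))            # [-inf, inf]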
Example #10
File: jet.py Project: 0x0is1/jax
    lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                                 lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """
    Define the jet rule for a primitive in terms of a composition of simpler primitives.
    """
    jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)
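A hedged sketch of how a def_comp rule gets exercised: jax.experimental.jet propagates truncated Taylor series through primitives such as lax.sqrt_p (the exact coefficient convention is documented in jax.experimental.jet):

import jax.numpy as jnp
from jax.experimental import jet

primal_out, series_out = jet.jet(jnp.sqrt, (4.0,), ((1.0, 0.0),))
print(primal_out)      # 2.0, i.e. sqrt(4)
print(series_out[0])   # 0.25, the first-order (JVP) coefficient 1/(2*sqrt(4))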
Example #11
  def apply(self,
            x,
            layer=LAYER_EVONORM_B0,
            nonlinearity=True,
            num_groups=32,
            group_size=None,
            batch_stats=None,
            use_running_average=False,
            axis=-1,
            momentum=0.99,
            epsilon=1e-5,
            dtype=jnp.float32,
            bias=True,
            scale=True,
            bias_init=initializers.zeros,
            scale_init=initializers.ones,
            axis_name=None,
            axis_index_groups=None):
    """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      layer: LAYER_EVONORM_B0 or LAYER_EVONORM_S0.
      nonlinearity: use the EvoNorm nonlinearity.
      num_groups: number of groups to use for group statistics.
      group_size: size of groups, see nn.GroupNorm.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
          example, `[[0, 1], [2, 3]]` would independently batch-normalize over
          the examples on the first two and last two devices. See `jax.lax.psum`
          for more details.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    x = jnp.asarray(x, jnp.float32)

    axis = axis if isinstance(axis, tuple) else (axis,)
    # pylint: disable=protected-access
    axis = nn.normalization._absolute_dims(x.ndim, axis)
    # pylint: enable=protected-access
    feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
    reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
    instance_reduction_axis = tuple(
        i for i in range(x.ndim) if i not in axis and i > 0)
    batch_reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

    if nonlinearity:
      v = self.param('v', reduced_feature_shape,
                     jax.nn.initializers.ones).reshape(feature_shape)
      if layer == LAYER_EVONORM_S0:
        den, group_shape, input_shape = _GroupStd(
            x,
            num_groups=num_groups,
            group_size=group_size,
            epsilon=epsilon,
            dtype=dtype,
        )
        x = x * nn.sigmoid(v * x)
        x = x.reshape(group_shape)
        x /= den
        x = x.reshape(input_shape)
      elif layer == LAYER_EVONORM_B0:
        if self.is_stateful() or batch_stats:
          ra_var = self.state(
              'var',
              reduced_feature_shape,
              initializers.ones,
              collection=batch_stats)
        else:
          ra_var = None

        if use_running_average:
          if ra_var is None:
            raise ValueError(
                'when use_running_averages is True '
                'either use a stateful context or provide batch_stats')
          var = ra_var.value
        else:
          mean = jnp.mean(x, axis=batch_reduction_axis, keepdims=False)
          mean2 = jnp.mean(
              lax.square(x), axis=batch_reduction_axis, keepdims=False)
          if axis_name is not None and not self.is_initializing():
            concatenated_mean = jnp.concatenate([mean, mean2])
            mean, mean2 = jnp.split(
                lax.pmean(
                    concatenated_mean,
                    axis_name=axis_name,
                    axis_index_groups=axis_index_groups), 2)
          var = mean2 - lax.square(mean)

          if ra_var and not self.is_initializing():
            ra_var.value = momentum * ra_var.value + (1 - momentum) * var

        left = lax.sqrt(var + epsilon)

        instance_std = jnp.sqrt(
            x.var(axis=instance_reduction_axis, keepdims=True) + epsilon)
        right = v * x + instance_std
        x = x / jnp.maximum(left, right)
      else:
        raise ValueError('Unknown EvoNorm layer: {}'.format(layer))

    if scale:
      x *= self.param('scale', reduced_feature_shape,
                      scale_init).reshape(feature_shape)
    if bias:
      x = x + self.param('bias', reduced_feature_shape,
                         bias_init).reshape(feature_shape)
    return jnp.asarray(x, dtype)
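A plain-jnp sketch (not the Flax module above, and without the running-average and cross-device parts) of the EvoNorm-B0 core: divide x by the elementwise maximum of the batch standard deviation and v*x plus the per-instance standard deviation.

import jax
import jax.numpy as jnp

eps = 1e-5
x = jax.random.normal(jax.random.PRNGKey(0), (8, 16, 16, 32))   # N, H, W, C
v = jnp.ones((1, 1, 1, 32))                                     # learned per-channel parameter

batch_std = jnp.sqrt(x.var(axis=(0, 1, 2), keepdims=True) + eps)
instance_std = jnp.sqrt(x.var(axis=(1, 2), keepdims=True) + eps)
y = x / jnp.maximum(batch_std, v * x + instance_std)
print(y.shape)                                                  # (8, 16, 16, 32)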
Example #12
  def apply(
      self,
      x,
      num_groups=32,
      group_size=None,
      epsilon=1e-6,
      dtype=jnp.float32,
  ):
    """Applies group normalization to the input (arxiv.org/abs/1803.08494).

    This op is similar to batch normalization, but statistics are shared across
    equally-sized groups of channels and not shared across batch dimension.
    Thus, group normalization does not depend on the batch composition and does
    not require maintaining internal state for storing statistics.

    The user should either specify the total number of channel groups or the
    number of channels per group.

    Args:
      x: the input of shape N...C, where N is a batch dimension and C is a
        channels dimensions. `...` represents an arbitrary number of extra
        dimensions that are used to accumulate statistics over.
      num_groups: the total number of channel groups. The default value of 32 is
        proposed by the original group normalization paper.
      group_size: the number of channels in a group.
      epsilon: A small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).

    Returns:
      Normalized inputs (the same shape as inputs).

    """
    x = jnp.asarray(x, jnp.float32)
    if ((num_groups is None and group_size is None) or
        (num_groups is not None and group_size is not None)):
      raise ValueError('Either `num_groups` or `group_size` should be '
                       'specified, but not both of them.')

    channels = x.shape[-1]
    if group_size is not None:
      if channels % group_size != 0:
        raise ValueError('Number of channels ({}) is not multiple of the '
                         'group size ({}).'.format(channels, group_size))
      num_groups = channels // group_size
    while num_groups > 1:
      if channels % num_groups == 0:
        break
      num_groups -= 1

    group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)

    input_shape = x.shape
    x = x.reshape(group_shape)

    reduction_axis = list(range(1, x.ndim - 2)) + [x.ndim - 1]

    mean = jnp.mean(x, axis=reduction_axis, keepdims=True)
    mean_of_squares = jnp.mean(
        jnp.square(x), axis=reduction_axis, keepdims=True)
    var = mean_of_squares - jnp.square(mean)

    std = lax.sqrt(var + epsilon)

    return std.astype(dtype), group_shape, input_shape
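A plain-jnp sketch of the same group-statistics computation: reshape the channel axis into (num_groups, channels_per_group) and reduce over the spatial axes plus the within-group channel axis.

import jax
import jax.numpy as jnp

eps = 1e-6
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 8, 32))     # N, H, W, C
num_groups = 4
g = x.reshape(x.shape[:-1] + (num_groups, x.shape[-1] // num_groups))
reduction_axis = list(range(1, g.ndim - 2)) + [g.ndim - 1]      # H, W and channels-in-group
std = jnp.sqrt(g.var(axis=reduction_axis, keepdims=True) + eps)
print(std.shape)                                                # (4, 1, 1, 4, 1)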
Example #13
  def apply(self,
            x,
            batch_stats=None,
            use_running_average=False,
            axis=-1,
            momentum=0.99,
            epsilon=1e-5,
            dtype=jnp.float32,
            axis_name=None,
            axis_index_groups=None):
    """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
          example, `[[0, 1], [2, 3]]` would independently batch-normalize over
          the examples on the first two and last two devices. See `jax.lax.psum`
          for more details.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    x = jnp.asarray(x, jnp.float32)
    axis = axis if isinstance(axis, tuple) else (axis,)
    # pylint: disable=protected-access
    axis = nn.normalization._absolute_dims(x.ndim, axis)
    # pylint: enable=protected-access
    reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
    reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
    if self.is_stateful() or batch_stats:
      ra_var = self.state(
          'var',
          reduced_feature_shape,
          initializers.ones,
          collection=batch_stats)
    else:
      ra_var = None

    if use_running_average:
      if ra_var is None:
        raise ValueError('when use_running_averages is True '
                         'either use a stateful context or provide batch_stats')
      var = ra_var.value
    else:
      mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
      mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
      if axis_name is not None and not self.is_initializing():
        concatenated_mean = jnp.concatenate([mean, mean2])
        mean, mean2 = jnp.split(
            lax.pmean(
                concatenated_mean,
                axis_name=axis_name,
                axis_index_groups=axis_index_groups), 2)
      var = mean2 - lax.square(mean)

      if ra_var and not self.is_initializing():
        ra_var.value = momentum * ra_var.value + (1 - momentum) * var

    mul = lax.sqrt(var + epsilon)

    return jnp.asarray(mul, dtype)
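A minimal sketch of the running-variance update performed above when use_running_average is False (the shapes and momentum value are illustrative):

import jax.numpy as jnp

momentum = 0.99
ra_var = jnp.ones(16)                           # stored running variance, one per feature
batch_var = jnp.full(16, 4.0)                   # variance of the current batch
ra_var = momentum * ra_var + (1 - momentum) * batch_var
print(ra_var[0])                                # 1.03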
Example #14
 def __call__(self, x):
     #(1 + x/np.sqrt(4 + x**2))/2
     return lax.mul(
         0.5, lax.add(lax.div(x, lax.sqrt(lax.add(lax.square(x), 4.))), 1.))
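The commented expression is the derivative of the squareplus form in the examples that follow (#15 and #16); a quick cross-check against jax.grad, using standalone lambdas rather than the classes above:

import jax
import jax.numpy as jnp

squareplus = lambda x: 0.5 * (x + jnp.sqrt(x * x + 4.0))
dsquareplus = lambda x: 0.5 * (x / jnp.sqrt(x * x + 4.0) + 1.0)
x = 1.5
print(jax.grad(squareplus)(x), dsquareplus(x))  # both 0.8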
Example #15
 def __call__(self, x):
     #(x + np.sqrt(x**2 + 4))/2
     return lax.mul(0.5, lax.add(x, lax.sqrt(lax.add(lax.square(x), 4.))))
Example #16
def squareplus(x):
    return lax.mul(0.5, lax.add(x, lax.sqrt(lax.add(lax.square(x), 4.0))))
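Quick check of the definition above (assumes squareplus and the lax import are in scope): squareplus(0) is exactly 1, and for large |x| it approaches relu(x).

import jax.numpy as jnp
from jax import lax

xs = jnp.array([-10.0, 0.0, 10.0])
print(squareplus(xs))                   # ~[0.099, 1.0, 10.099]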