Example #1
import jax.numpy as jnp
from jax import lax


def reciprocal(tensor, dtype, recip_hparams):
    """Computes the reciprocal of `tensor`, or a clipped linear
    approximation of it controlled by `recip_hparams`."""
    if recip_hparams is not None and recip_hparams.linear_gradient != 0:
        # Want a linear surrogate -a*x + b that passes through (1, 1),
        # i.e. b = a + 1, clipped to the range [low_bound, 1]:
        # clip(a + 1 - a*x, low_bound, 1) for arbitrary a > 0.
        afull = jnp.full(tensor.shape,
                         recip_hparams.linear_gradient).astype(dtype)
        aplus1full = jnp.full(tensor.shape,
                              1 + recip_hparams.linear_gradient).astype(dtype)
        arecip = jnp.clip(lax.sub(aplus1full, lax.mul(afull, tensor)),
                          recip_hparams.low_bound, 1.).astype(dtype)
    else:
        arecip = lax.reciprocal(tensor)
    return arecip
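
The clipped linear path is easier to see with concrete numbers. A minimal usage sketch, assuming a hypothetical RecipHParams container with linear_gradient and low_bound fields (the real hyper-parameter class is not shown in the example):

import collections

import jax.numpy as jnp

# Hypothetical stand-in for the hyper-parameter object; only the two
# fields read by reciprocal() are modeled.
RecipHParams = collections.namedtuple('RecipHParams',
                                      ['linear_gradient', 'low_bound'])

x = jnp.array([0.5, 1.0, 2.0])
hparams = RecipHParams(linear_gradient=2.0, low_bound=0.1)

# Linear branch: clip(a + 1 - a*x, low_bound, 1) with a = 2
print(reciprocal(x, jnp.float32, hparams))  # [1.  1.  0.1]
# Exact branch: 1/x
print(reciprocal(x, jnp.float32, None))     # [2.  1.  0.5]

Both branches agree at x = 1, the point the linear surrogate is anchored to.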
Example #2
# Import aliases inferred from the j./jl. prefixes used below.
import jax as j
import jax.lax as jl


def correlators(g, alpha, xs):
    """Populates the supplied vector with correlator values, assuming the
    first three entries have been populated.
    """
    n = xs.shape[0]

    if n <= 2:
        return xs

    g_recip = jl.reciprocal(g)

    def run(k, xs):
        new_v = t_k_plus_3(k - 3, alpha, g_recip, xs)
        # jax.ops.index_update predates the xs.at[k].set(new_v) API of
        # current JAX releases.
        return j.ops.index_update(xs, k, new_v)

    # fori_loop's upper bound is exclusive: fill indices 3 through n - 1.
    return jl.fori_loop(3, n, run, xs)
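
Since t_k_plus_3 is not shown, here is a self-contained sketch of the same pattern with a toy recurrence standing in for it: lax.fori_loop carries the array and each step writes one slot. It uses the current xs.at[k].set(...) functional update, which replaced jax.ops.index_update in newer JAX releases.

import jax.numpy as jnp
from jax import lax

def fill_recurrence(xs):
    """Fills xs[3:] via the toy recurrence x[k] = x[k-1] + x[k-3]."""
    def body(k, xs):
        # Functional update: returns a new array with slot k replaced.
        return xs.at[k].set(xs[k - 1] + xs[k - 3])
    return lax.fori_loop(3, xs.shape[0], body, xs)

xs = jnp.array([1.0, 2.0, 3.0, 0.0, 0.0])
print(fill_recurrence(xs))  # [1. 2. 3. 4. 6.]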
Example #3
def xlog1py_jvp_rhs(g, x, y, jaxpr, aval, consts):
    # JVP of xlog1py(x, y) = x * log1p(y) with respect to y: the partial
    # derivative is x / (1 + y), so the tangent is g * x / (1 + y).
    # _promote_args_like, lax._brcast and lax._safe_mul are JAX-internal
    # helpers from the version this snippet targets; _safe_mul keeps the
    # product zero where x == 0 even if 1 / (1 + y) is infinite.
    x, y = _promote_args_like(osp_special.xlog1py, x, y)
    g, x = _promote_args_like(osp_special.xlog1py, g, x)
    jac = lax._safe_mul(lax._brcast(x, y),
                        lax._brcast(lax.reciprocal(1 + y), x))
    return lax.mul(lax._brcast(g, jac), jac)
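
The tangent rule above encodes d/dy xlog1py(x, y) = x / (1 + y). A quick sanity check against the public API, assuming a current JAX install (jax.grad and jax.scipy.special.xlog1py are not part of the snippet itself):

import jax
from jax.scipy import special

x, y = 3.0, 0.5
dy = jax.grad(special.xlog1py, argnums=1)(x, y)
print(dy, x / (1 + y))  # both 2.0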
Example #4

    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.

        Args:
            x: the input to be normalized.
            training: if True, normalize with statistics of the current batch
                and update the running averages; if False, normalize with the
                stored running averages.

        Returns:
            Normalized inputs (the same shape as the input).
        """
        x = jnp.asarray(x, self.dtype)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # Detect initialization: the 'batch_stats' collection has not been
        # populated yet on the first call.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            # Mean absolute deviation scaled by sqrt(pi/2): for Gaussian
            # inputs this estimates the standard deviation, not the
            # variance, which is why a plain reciprocal (rather than
            # rsqrt) is used below.
            var = jnp.mean(lax.abs(x - mean),
                           axis=reduction_axis,
                           keepdims=False) * jnp.sqrt(jnp.pi / 2)
            if self.axis_name is not None and not initializing:
                # Concatenate mean and var so cross-device synchronization
                # costs a single pmean.
                concatenated_mean = jnp.concatenate([mean, var])
                mean, var = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        mean = jnp.asarray(mean, self.dtype)
        var = jnp.asarray(var, self.dtype)
        y = x - mean.reshape(feature_shape)
        mul = lax.reciprocal(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            scale = jnp.asarray(scale, self.dtype)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return jnp.asarray(y, self.dtype)
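
One non-obvious piece above is the statistic stored in var: for Gaussian data, E|x - mu| = sigma * sqrt(2/pi), so the mean absolute deviation scaled by sqrt(pi/2) estimates the standard deviation rather than the variance, which is why the code multiplies by reciprocal(var + epsilon) instead of rsqrt. A minimal numeric check of that identity, independent of the module:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = 3.0 * jax.random.normal(key, (100_000,))  # sigma = 3

mad_scaled = jnp.mean(jnp.abs(x - jnp.mean(x))) * jnp.sqrt(jnp.pi / 2)
print(mad_scaled, jnp.std(x))  # both close to 3.0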