import jax.numpy as jnp
from jax import lax


def reciprocal(tensor, dtype, recip_hparams):
  """Generates a reciprocal function based on recip hyper params."""
  if recip_hparams is not None and recip_hparams.linear_gradient != 0:
    # Want: max(low_bound, -a*x + b) such that (-a*x + b) goes through (1, 1).
    # Solution: max(low_bound, a + 1 - a*x) for arbitrary a > 0.
    afull = jnp.full(tensor.shape, recip_hparams.linear_gradient).astype(dtype)
    aplus1full = jnp.full(tensor.shape,
                          1 + recip_hparams.linear_gradient).astype(dtype)
    arecip = jnp.clip(
        lax.sub(aplus1full, lax.mul(afull, tensor)),
        recip_hparams.low_bound, 1.).astype(dtype)
  else:
    arecip = lax.reciprocal(tensor)
  return arecip
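# Usage sketch (illustrative only): `SimpleRecipHParams` is an assumed
# stand-in for whatever hparams object the caller provides; `reciprocal`
# only reads its `linear_gradient` and `low_bound` attributes. With hparams
# set, the result is the clipped linear surrogate
# max(low_bound, min(1, a + 1 - a*x)); with `recip_hparams=None` it falls
# back to the exact `lax.reciprocal`.
import collections

SimpleRecipHParams = collections.namedtuple(
    'SimpleRecipHParams', ['linear_gradient', 'low_bound'])


def _reciprocal_example():
  x = jnp.linspace(0.5, 4.0, num=8)
  hparams = SimpleRecipHParams(linear_gradient=0.5, low_bound=0.1)
  surrogate = reciprocal(x, jnp.float32, hparams)  # clipped linear surrogate
  exact = reciprocal(x, jnp.float32, None)         # plain 1 / x
  return surrogate, exact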
def correlators(g, alpha, xs):
  """Populates the supplied vector with correlator values, assuming the
  first three entries have been populated.
  """
  n = xs.shape[0]
  if n <= 2:
    return xs

  g_recip = jl.reciprocal(g)

  def run(k, xs):
    new_v = t_k_plus_3(k - 3, alpha, g_recip, xs)
    return j.ops.index_update(xs, k, new_v)

  return jl.fori_loop(3, n + 1, run, xs)
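# Usage sketch (hypothetical): `t_k_plus_3`, `j`, and `jl` are assumed to come
# from the surrounding module and are application-specific, so only the
# calling convention is shown here -- seed the first three correlator values,
# then let `correlators` fill in the rest via `fori_loop`.
def _correlators_example(g, alpha, seed, n):
  # `seed` holds the first three correlator values; the remaining slots start
  # at zero and are overwritten inside `correlators`.
  xs = jnp.concatenate([jnp.asarray(seed), jnp.zeros(n - 3)])
  return correlators(g, alpha, xs)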
def xlog1py_jvp_rhs(g, x, y, jaxpr, aval, consts):
  x, y = _promote_args_like(osp_special.xlog1py, x, y)
  g, x = _promote_args_like(osp_special.xlog1py, g, x)
  # d/dy [x * log1p(y)] = x / (1 + y).
  jac = lax._safe_mul(lax._brcast(x, y),
                      lax._brcast(lax.reciprocal(1 + y), x))
  return lax.mul(lax._brcast(g, jac), jac)
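# Sanity-check sketch: the rule above encodes d/dy xlog1py(x, y) = x / (1 + y).
# This standalone check uses only the public JAX API; the names below are local
# to the example and not part of the JVP rule itself.
from jax import grad
from jax.scipy.special import xlog1py


def _xlog1py_rhs_check(x=2.0, y=0.5):
  auto = grad(xlog1py, argnums=1)(x, y)  # autodiff derivative w.r.t. y
  manual = x / (1.0 + y)                 # closed form
  return auto, manual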
def __call__(self, x, training: bool):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    training: if True, normalize with (and update the moving averages from)
      the statistics of the current batch; if False, use the stored moving
      averages.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  x = jnp.asarray(x, self.dtype)
  axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
  axis = _absolute_dims(x.ndim, axis)
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

  # We detect if we're in initialization via an empty variable tree.
  initializing = not self.has_variable('batch_stats', 'mean')

  ra_mean = self.variable('batch_stats', 'mean',
                          lambda s: jnp.zeros(s, jnp.float32),
                          reduced_feature_shape)
  ra_var = self.variable('batch_stats', 'var',
                         lambda s: jnp.ones(s, jnp.float32),
                         reduced_feature_shape)

  if not training:
    mean, var = ra_mean.value, ra_var.value
  else:
    mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
    # L1 batch statistic: for Gaussian inputs, E|x - mean| * sqrt(pi / 2)
    # estimates the standard deviation.
    var = jnp.mean(lax.abs(x - mean), axis=reduction_axis,
                   keepdims=False) * jnp.sqrt(jnp.pi / 2)
    if self.axis_name is not None and not initializing:
      concatenated_mean = jnp.concatenate([mean, var])
      mean, var = jnp.split(
          lax.pmean(concatenated_mean,
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
    if not initializing:
      ra_mean.value = self.momentum * ra_mean.value + (
          1 - self.momentum) * mean
      ra_var.value = self.momentum * ra_var.value + (
          1 - self.momentum) * var

  mean = jnp.asarray(mean, self.dtype)
  var = jnp.asarray(var, self.dtype)
  y = x - mean.reshape(feature_shape)
  # `var` holds a deviation-style statistic, so divide by it directly.
  mul = lax.reciprocal(var + self.epsilon)
  if self.use_scale:
    scale = self.param('scale', self.scale_init,
                       reduced_feature_shape).reshape(feature_shape)
    scale = jnp.asarray(scale, self.dtype)
    mul = mul * scale
  y = y * mul
  if self.use_bias:
    bias = self.param('bias', self.bias_init,
                      reduced_feature_shape).reshape(feature_shape)
    bias = jnp.asarray(bias, self.dtype)
    y = y + bias
  return jnp.asarray(y, self.dtype)
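# Usage sketch (hypothetical): the Module class that owns this `__call__` is
# not shown above, so `AbsDevBatchNorm` is an assumed name for it. The calling
# pattern below follows the usual Flax linen convention of marking the
# `batch_stats` collection as mutable during training steps.
#
#   model = AbsDevBatchNorm(momentum=0.9, epsilon=1e-5,
#                           use_scale=True, use_bias=True)
#   variables = model.init(jax.random.PRNGKey(0), x, training=True)
#   y, updates = model.apply(variables, x, training=True,
#                            mutable=['batch_stats'])
#   y_eval = model.apply({**variables, **updates}, x, training=False)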