def isfinite(x):
  _check_arraylike("isfinite", x)
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.is_finite(x)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
  else:
    return lax.full_like(x, True, dtype=np.bool_)
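# Usage sketch (an addition, not part of the implementation above): assuming
# the function mirrors the public jax.numpy.isfinite, it is elementwise True
# only for finite floats, requires both real and imaginary parts to be finite
# for complex inputs, and is trivially True for integer/bool dtypes.
import numpy as np
import jax.numpy as jnp

print(jnp.isfinite(jnp.array([1.0, np.inf, -np.inf, np.nan])))  # [ True False False False]
print(jnp.isfinite(jnp.array([1.0 + 1j, np.inf + 0j])))         # [ True False]
print(jnp.isfinite(jnp.array([1, 2, 3])))                       # [ True  True  True]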
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = jnp.broadcast_arrays(a, b)
  dims = _reduction_dims(a, axis)
  dimadd = lambda x: lax.expand_dims(x, dims)
  # Subtract the max before exponentiating for numerical stability; replace
  # non-finite maxima with 0 so that all-(-inf) or nan inputs do not poison
  # the shifted values.
  amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_singletons = dimadd(amax)
  if b is None:
    out = lax.add(
        lax.log(
            lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                       _constant_like(a, 0), lax.add, dims)), amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    # With weights b the sum can be negative, so take log|sum| and track the
    # sign separately.
    sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                        _constant_like(a, 0), lax.add, dims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return dimadd(out) if keepdims else out
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  pos_dims, dims = _reduction_dims(a, axis)
  # Subtract the max before exponentiating for numerical stability; replace
  # non-finite maxima with 0 so they cannot poison the shifted values.
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  if b is None:
    out = lax.add(
        lax.log(
            jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                    axis=dims, keepdims=keepdims)), amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    # With weights b the sum can be negative, so take log|sum| and track the
    # sign separately.
    sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                     axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return out
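# Usage sketch (an addition, not part of the implementation above): assuming
# this matches the public jax.scipy.special.logsumexp API, `b` weights each
# term inside the sum and `return_sign=True` additionally returns the sign of
# the weighted sum, so the result can be recovered as sign * exp(out).
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1.0, 2.0, 3.0])
print(logsumexp(a))                  # log(sum(exp(a))), computed stably
print(jnp.log(jnp.sum(jnp.exp(a))))  # same value, but prone to overflow for large a

b = jnp.array([1.0, -2.0, 1.0])      # negative weights make the sum sign-ambiguous
out, sign = logsumexp(a, b=b, return_sign=True)
print(sign * jnp.exp(out))           # ~= sum(b * exp(a))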
def grad_fn_wrapper(*args):
  aux, grad = grad_fn(*args)
  # Undo the loss scaling on both the auxiliary output and the gradients.
  aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale
  grad = jax.tree_map(
      lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad)
  if axis_name is not None:
    grad = lax.pmean(grad, axis_name)

  # Check that every gradient leaf is finite (no inf/nan from overflow).
  finite = jnp.array(True)
  for g in jax.tree_leaves(grad):
    finite &= jnp.all(lax.is_finite(g))

  # Grow the scale after `growth_interval` consecutive finite steps; back it
  # off immediately when any gradient is non-finite.
  grow = self.fin_steps == self.growth_interval
  fin_scale = jnp.where(grow & finite,
                        self.scale * self.growth_factor,
                        self.scale)
  inf_scale = self.scale * self.backoff_factor
  new_scale = jnp.where(finite, fin_scale, inf_scale)
  new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1)

  new_self = self.replace(fin_steps=new_fin_steps, scale=new_scale)
  return DynamicScaleResult(new_self, finite, aux, grad)
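# Sketch (an addition) of the scale-update rule above in isolation, written
# with plain jax.numpy values; growth_factor, backoff_factor, and
# growth_interval are example hyperparameters, not values taken from this code.
import jax.numpy as jnp

def update_scale(scale, fin_steps, grads_finite,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
  scale = jnp.asarray(scale, jnp.float32)
  fin_steps = jnp.asarray(fin_steps, jnp.int32)
  grads_finite = jnp.asarray(grads_finite, jnp.bool_)
  grow = fin_steps == growth_interval
  fin_scale = jnp.where(grow & grads_finite, scale * growth_factor, scale)
  inf_scale = scale * backoff_factor
  new_scale = jnp.where(grads_finite, fin_scale, inf_scale)
  new_fin_steps = jnp.where(grow | (~grads_finite), 0, fin_steps + 1)
  return new_scale, new_fin_steps

print(update_scale(65536.0, 2000, True))   # scale grows, finite-step counter resets
print(update_scale(65536.0, 10, False))    # scale backs off, counter resets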
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
  """Normalizes attention."""
  a = attn_weights

  def unquantized_softmax(a):
    a = lax.exp(
        a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
    return a.astype(dtype)

  # Quantize intermediate activations with QuantOps.
  # Currently only supports unscaled floating-point formats.
  def quantized_softmax(a):
    # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
    # intermediate values. Note this differs from the log-domain
    # implementation of softmax used above.
    quant_hparams = softmax_hparams.quant_hparams
    fp_quant_config = QuantOps.FloatQuant(
        is_scaled=False, fp_spec=quant_hparams.prec)
    quant_ops = QuantOps.create_symmetric_fp(
        fp_quant=fp_quant_config, bounds=None)

    a = quant_ops.to_quantized(a, dtype=dtype)
    # Note that the max of a quantized vector is necessarily also quantized to
    # the same precision since the max of a vector must be an existing element
    # of the vector, so we don't need to explicitly insert a quantization
    # operator to the output of the max reduction.
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
    a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

    sum_exp_quantized_reduction = quantization.quantized_sum(
        a_exp, axis=norm_dims, keepdims=True,
        prec=quant_hparams.reduction_prec)
    sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

    inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
    a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

    return a_softmax.astype(dtype)

  # If no params, return accurate Softmax.
  if softmax_hparams == SoftmaxHParams(None, None,
                                       None) or softmax_hparams is None:
    return unquantized_softmax(a)

  # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
  # the entire training run, even before the global activation start step.
  if softmax_hparams.quant_hparams is not None:
    return lax.cond(quant_context.quantize_acts, quantized_softmax,
                    unquantized_softmax, a)

  # Approximated Softmax
  exp_hparams = softmax_hparams.exp_hparams
  recip_hparams = softmax_hparams.reciprocal_hparams

  # Subtract max value from dimensions to be normalized.
  shape = jax.util.subvals(onp.shape(a), zip(norm_dims, (1,) * len(norm_dims)))
  dimadd = lambda x: lax.reshape(x, shape)
  # pylint: disable=protected-access
  amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                    norm_dims)
  amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
  amax_singletons = dimadd(amax)
  asubmax = lax.sub(a, amax_singletons)

  # Calculate approximated exponential
  approx_exp = exponential(asubmax, dtype, exp_hparams)

  # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
  asumexp = dimadd(
      lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                 norm_dims))

  if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
    sum_low_bound = 1.
    if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
      sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
    asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

  # Approximation of reciprocal.
  arecip = reciprocal(asumexp, dtype, recip_hparams)
  return lax.mul(approx_exp, arecip).astype(dtype)
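# Sketch (an addition, independent of the quantization machinery above): the
# log-domain form exp(a - logsumexp(a)) used by unquantized_softmax and the
# exp(a - max(a)) / sum(exp(a - max(a))) form quantized step by step in
# quantized_softmax compute the same quantity.
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([[2.0, 1.0, 0.1], [0.0, 5.0, 1.0]])
norm_dims = -1

log_domain = jnp.exp(a - logsumexp(a, axis=norm_dims, keepdims=True))

a_max = jnp.max(a, axis=norm_dims, keepdims=True)
a_exp = jnp.exp(a - a_max)
standard = a_exp / jnp.sum(a_exp, axis=norm_dims, keepdims=True)

print(jnp.allclose(log_domain, standard))                    # True
print(jnp.allclose(log_domain, jax.nn.softmax(a, axis=-1)))  # True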