Code Example #1
def isfinite(x):
    _check_arraylike("isfinite", x)
    dtype = dtypes.dtype(x)
    if dtypes.issubdtype(dtype, np.floating):
        return lax.is_finite(x)
    elif dtypes.issubdtype(dtype, np.complexfloating):
        return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
    else:
        return lax.full_like(x, True, dtype=np.bool_)
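A minimal usage sketch of this dispatch through the public `jax.numpy.isfinite` entry point (the input values are illustrative):

import jax.numpy as jnp

x = jnp.array([1.0, jnp.inf, -jnp.inf, jnp.nan])
print(jnp.isfinite(x))                                         # [ True False False False]
print(jnp.isfinite(jnp.array([1.0 + 2.0j, jnp.inf + 0.0j])))   # complex: both parts must be finite -> [ True False]
print(jnp.isfinite(jnp.array([1, 2, 3])))                      # integer dtype: always True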
Code Example #2
File: special.py  Project: xeransis/jax
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)
    dimadd = lambda x: lax.expand_dims(x, dims)
    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = dimadd(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return dimadd(out) if keepdims else out
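A hedged usage sketch of the `b` and `return_sign` parameters handled above, via the public `jax.scipy.special.logsumexp` entry point (the values are illustrative):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Weighted sum of exponentials with a negative weight: sum_i b_i * exp(a_i) = 1 - 2 = -1,
# so the magnitude is returned as log(1) = 0 together with sign -1.
out, sign = logsumexp(jnp.array([0.0, 0.0]), b=jnp.array([1.0, -2.0]), return_sign=True)
print(out, sign)  # 0.0 -1.0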
Code Example #3
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
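The rewritten version above keeps the same `lax.is_finite(amax)` guard; a small sketch of the edge case it handles, per that implementation (illustrative values):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

# With every input -inf the max is -inf; substituting 0 for the non-finite max
# avoids computing (-inf) - (-inf) = nan, so the result is -inf rather than nan.
a = jnp.full((3,), -jnp.inf)
print(logsumexp(a))                    # -inf
print(logsumexp(a, return_sign=True))  # (-inf, 0.0): the sign is zeroed when the result is -inf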
Code Example #4
File: dynamic_scale.py  Project: yanndupis/flax
        def grad_fn_wrapper(*args):
            aux, grad = grad_fn(*args)
            aux = (aux[0] / self.scale,
                   aux[1]) if has_aux else aux / self.scale

            grad = jax.tree_map(
                lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad)
            if axis_name is not None:
                grad = lax.pmean(grad, axis_name)

            # A single non-finite gradient leaf (inf or nan) marks the whole step as non-finite.
            finite = jnp.array(True)
            for g in jax.tree_leaves(grad):
                finite &= jnp.all(lax.is_finite(g))

            # Grow the loss scale after `growth_interval` consecutive finite steps;
            # back it off immediately whenever any gradient is non-finite.
            grow = self.fin_steps == self.growth_interval
            fin_scale = jnp.where(grow & finite,
                                  self.scale * self.growth_factor, self.scale)
            inf_scale = self.scale * self.backoff_factor
            new_scale = jnp.where(finite, fin_scale, inf_scale)
            new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1)

            new_self = self.replace(fin_steps=new_fin_steps, scale=new_scale)
            return DynamicScaleResult(new_self, finite, aux, grad)
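The loop over `jax.tree_leaves` above reduces `lax.is_finite` across every gradient leaf; a standalone sketch of that pattern (the pytree and its values are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

grads = {"w": jnp.array([[0.1, -0.2]]), "b": jnp.array([jnp.inf])}

# A single overflowed leaf marks the whole step as non-finite.
finite = jnp.array(True)
for g in jax.tree_util.tree_leaves(grads):
    finite &= jnp.all(lax.is_finite(g))
print(finite)  # False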
Code Example #5
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
    """Normalizes attention."""
    a = attn_weights

    def unquantized_softmax(a):
        a = lax.exp(
            a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
        return a.astype(dtype)

    # Quantize intermediate activations with QuantOps.
    # Currently only supports unscaled floating-point formats.
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly insert a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)

    # If no params, return accurate Softmax.
    if softmax_hparams == SoftmaxHParams(None, None,
                                         None) or softmax_hparams is None:
        return unquantized_softmax(a)

    # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
    # the entire training run, even before the global activation start step.
    if softmax_hparams.quant_hparams is not None:
        return lax.cond(quant_context.quantize_acts, quantized_softmax,
                        unquantized_softmax, a)

    # Approximated Softmax
    exp_hparams = softmax_hparams.exp_hparams
    recip_hparams = softmax_hparams.reciprocal_hparams

    # Subtract the max value from the dimensions to be normalized.
    shape = jax.util.subvals(onp.shape(a),
                             zip(norm_dims, (1, ) * len(norm_dims)))
    dimadd = lambda x: lax.reshape(x, shape)
    # pylint: disable=protected-access
    amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                      norm_dims)
    amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
    amax_singletons = dimadd(amax)
    asubmax = lax.sub(a, amax_singletons)

    # Calculate approximated exponential
    approx_exp = exponential(asubmax, dtype, exp_hparams)

    # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
    asumexp = dimadd(
        lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                   norm_dims))

    if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
        sum_low_bound = 1.
        if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
            sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
        asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

    # Approximation of reciprocal.
    arecip = reciprocal(asumexp, dtype, recip_hparams)
    return lax.mul(approx_exp, arecip).astype(dtype)
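The `lax.is_finite(amax)` guard in the approximate branch mirrors the logsumexp implementations above; a minimal sketch of the underlying max-subtraction trick in plain JAX, without the quantization machinery (the function name is illustrative):

import jax.numpy as jnp
from jax import lax

def stable_softmax(a, axis=-1):
    # Subtract the per-axis max so exp() cannot overflow; if that max is not
    # finite, substitute 0 for it, as the implementations above do.
    amax = jnp.max(a, axis=axis, keepdims=True)
    amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
    unnormalized = jnp.exp(a - amax)
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

print(stable_softmax(jnp.array([[1.0, 2.0, 3.0]])))  # ~[0.09, 0.24, 0.67]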