def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = jnp.broadcast_arrays(a, b)
  dims = _reduction_dims(a, axis)
  dimadd = lambda x: lax.expand_dims(x, dims)
  amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_singletons = dimadd(amax)
  if b is None:
    out = lax.add(
        lax.log(
            lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                       _constant_like(a, 0), lax.add, dims)),
        amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                        _constant_like(a, 0), lax.add, dims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return dimadd(out) if keepdims else out
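This first version works directly with lax primitives: it reduces with lax.reduce, re-expands the reduced dimensions by hand via dimadd, and applies the usual max-shift trick, log(sum(exp(a))) = amax + log(sum(exp(a - amax))), so that exp never sees a large argument. As a quick sanity check of that trick (a minimal sketch, assuming these listings back the public jax.scipy.special.logsumexp entry point):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1000.0, 1000.0, 1000.0])
naive = jnp.log(jnp.sum(jnp.exp(a)))   # exp(1000.) overflows in float32 -> inf
stable = logsumexp(a)                  # max-shifted -> 1000 + log(3) ~= 1001.0986
print(naive, stable)

Replacing a non-finite amax with 0 (under stop_gradient) keeps the shift a harmless no-op when every input is -inf or +inf, instead of producing inf - inf = NaN.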
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  if b is None:
    out = lax.add(
        lax.log(
            jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                    axis=dims, keepdims=keepdims)),
        amax)
    sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
    sign = jnp.where(out == -np.inf, 0.0, sign)
  else:
    sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                     axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(lax.sign(sumexp))
    out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    out = jnp.where(sign < 0, np.nan, out)
  return out
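The second version swaps the raw lax.reduce calls for jnp.max/jnp.sum (which handle axis and keepdims directly), promotes a and b to a common inexact dtype, and masks zero weights by setting the corresponding entries of a to -inf before anything else. With scaling factors, the function computes log(sum(b * exp(a))), and the masking means zero-weight entries are ignored entirely, including when choosing amax, so a huge but zero-weighted entry cannot dominate the shift. A small hedged example of that semantics (same assumption as above about the public entry point):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([0.0, 1.0, 1000.0])
b = jnp.array([1.0, 2.0, 0.0])      # the zero weight masks the huge entry
# log(1*e**0 + 2*e**1) = log(1 + 2e) ~= 1.862
print(logsumexp(a, b=b))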
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  # fast path if the result cannot be negative.
  if b is None and not np.issubdtype(a.dtype, np.complexfloating):
    out = lax.add(
        lax.log(
            jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                    axis=dims, keepdims=keepdims)),
        amax)
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    expsub = lax.exp(lax.sub(a, amax_with_dims))
    if b is not None:
      expsub = lax.mul(expsub, b)
    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

    sign = lax.stop_gradient(jnp.sign(sumexp))
    if np.issubdtype(sumexp.dtype, np.complexfloating):
      if return_sign:
        sumexp = sign * sumexp
      out = lax.add(lax.log(sumexp), amax)
    else:
      out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    if not np.issubdtype(out.dtype, np.complexfloating):
      with jax.debug_nans(False):
        out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out
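The third version also promotes the unweighted input, adds complex-dtype handling, and restricts the unsigned fast path to the case where the result cannot be negative (no weights, non-complex dtype). On the general path the sign of the weighted sum is tracked separately: with return_sign=True the caller gets log|sum(b * exp(a))| together with its sign, while without it a negative weighted sum has no real logarithm and comes back as NaN (the jax.debug_nans(False) guard keeps that deliberate NaN from tripping NaN debugging). A minimal sketch of the signed-weight behaviour, under the same assumption that these listings are what jax.scipy.special.logsumexp runs:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([2.0, 1.0])
b = jnp.array([1.0, -1.0])                       # signed weights: e**2 - e**1 > 0
out, sign = logsumexp(a, b=b, return_sign=True)
print(out, sign)                                 # log(e**2 - e) ~= 1.5413, sign = 1.0

out, sign = logsumexp(a, b=-b, return_sign=True) # weighted sum is now negative
print(out, sign)                                 # ~= 1.5413, sign = -1.0
print(logsumexp(a, b=-b))                        # nan: no real log of a negative sum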