Example #1
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  # Guard the division: substitute ones where x1 == 0 so that 0/0 never occurs;
  # the outer select then returns 0 (== x1) for the all-zero case.
  safe_x1 = lax.select(x1 == 0, lax_internal._ones(x1), x1)
  return lax.select(x1 == 0, x1,
                    x1 * lax.sqrt(1 + lax.square(lax.div(x2, safe_x1))))
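A quick usage sketch (assuming `jax.numpy` is imported as `jnp`; the function above appears to be the implementation behind `jnp.hypot`): the nested selects keep the all-zero case from producing NaN via 0/0.

import jax.numpy as jnp

print(jnp.hypot(3.0, 4.0))  # 5.0
print(jnp.hypot(0.0, 0.0))  # 0.0 rather than NaN from 0/0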
Example #2
def assert_func(error: Error, pred: Bool, msg: str,
                payload: Optional[Payload]) -> Error:
    code = next_code()
    payload = init_payload if payload is None else payload
    out_err = error.err | jnp.logical_not(pred)
    out_code = lax.select(error.err, error.code, code)
    out_payload = lax.select(error.err, error.payload, payload)
    return Error(out_err, out_code, {code: msg, **error.msgs}, out_payload)
Example #3
def _float_divmod(x1, x2):
    # see float_divmod in floatobject.c of CPython
    mod = lax.rem(x1, x2)
    div = lax.div(lax.sub(x1, mod), x2)

    ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod))
    mod = lax.select(ind, mod + x2, mod)
    div = lax.select(ind, div - _constant_like(div, 1), div)

    return lax.round(div), mod
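A short behavioral check (a sketch, not part of the original source): as in Python's `divmod`, the corrected remainder takes the sign of the divisor, which is exactly what the `lax.select` fix-up enforces.

import jax.numpy as jnp

print(jnp.divmod(7.0, -2.0))   # (-4.0, -1.0)
print(jnp.divmod(-7.0, 2.0))   # (-4.0,  1.0)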
Example #4
def _ndtr(x):
    """Implements ndtr core logic."""
    dtype = lax.dtype(x).type
    half_sqrt_2 = dtype(0.5) * np.sqrt(2., dtype=dtype)
    w = x * half_sqrt_2
    z = lax.abs(w)
    y = lax.select(
        lax.lt(z, half_sqrt_2),
        dtype(1.) + lax.erf(w),
        lax.select(lax.gt(w, dtype(0.)),
                   dtype(2.) - lax.erfc(z), lax.erfc(z)))
    return dtype(0.5) * y
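For reference, the `_ndtr` kernel above appears to underlie the public `jax.scipy.special.ndtr`, which can be exercised directly (a minimal sketch):

from jax.scipy.special import ndtr

print(ndtr(0.0))   # 0.5
print(ndtr(1.96))  # ~0.975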
Example #5
 def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
   rng = rng_factory(self.rng())
   pred = rng(pred_shape, onp.bool_)
   on_true = rng(arg_shape, dtype)
   on_false = rng(arg_shape, dtype)
   select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
   check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)
Example #6
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)
    dimadd = lambda x: lax.expand_dims(x, dims)
    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = dimadd(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return dimadd(out) if keepdims else out
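A small sketch of why the `amax` shift matters: the naive formulation overflows for large inputs, while the shifted one stays finite.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1000.0, 1000.0])
print(jnp.log(jnp.sum(jnp.exp(a))))  # inf: exp(1000) overflows
print(logsumexp(a))                  # ~1000.6931: subtracting amax keeps it finite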
Example #7
 def __call__(self,
              inputs_q: Array,
              inputs_kv: Array,
              mask: Optional[Array] = None):
     query = self.dense(name="query")(inputs_q)
     key = self.dense(name="key")(inputs_kv)
     value = self.dense(name="value")(inputs_kv)
     if mask is not None:
         attention_bias = lax.select(
             mask > 0,
             jnp.full(mask.shape, 0).astype(self.dtype),
             jnp.full(mask.shape, -1e10).astype(self.dtype))
     else:
         attention_bias = None
     dropout_rng = None
     if not self.deterministic and self.dropout_rate > 0:
         dropout_rng = self.make_rng("dropout")
     x = nn.attention.dot_product_attention(
         query,
         key,
         value,
         bias=attention_bias,
         dropout_rng=dropout_rng,
         dropout_rate=self.dropout_rate,
         deterministic=self.deterministic,
         dtype=self.dtype)
     output = nn.DenseGeneral(features=self.out_features,
                              axis=(-2, -1),
                              dtype=self.dtype,
                              name="out")(x)
     return output
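The mask-to-bias conversion used here (and in several of the attention examples below) can be illustrated in isolation; this is a minimal sketch with made-up values, not code from the surrounding module:

import jax
import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1, 1, 0]])  # 1 = attend, 0 = ignore
bias = lax.select(mask > 0,
                  jnp.zeros(mask.shape, jnp.float32),
                  jnp.full(mask.shape, -1e10, jnp.float32))
print(jax.nn.softmax(bias))    # the masked position gets ~0 attention weight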
Example #8
    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        deterministic=True,
        output_attentions: bool = False,
    ):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(
            hidden_states).reshape(hidden_states.shape[:2] +
                                   (self.config.num_attention_heads, head_dim))
        value_states = self.value(
            hidden_states).reshape(hidden_states.shape[:2] +
                                   (self.config.num_attention_heads, head_dim))
        key_states = self.key(
            hidden_states).reshape(hidden_states.shape[:2] +
                                   (self.config.num_attention_heads, head_dim))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights,
                                      layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights,
                                 value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1, ))

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example #9
    def __call__(self,
                 inputs,
                 deterministic=False,
                 rng=None,
                 broadcast_dims=()):
        """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
        masked, whereas if true, no mask is applied and the inputs are returned as
        is.
      rng: an optional `jax.random.PRNGKey`. By default `nn.make_rng()` will
        be used.
      broadcast_dims: dimensions along which the dropout mask is broadcast;
        a single keep/drop decision is shared along each of these dimensions.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
        if self.rate == 0.:
            return inputs
        keep_prob = 1. - self.rate
        if deterministic:
            return inputs
        else:
            if rng is None:
                rng = self.make_rng('dropout')
            broadcast_shape = list(inputs.shape)
            for dim in broadcast_dims:
                broadcast_shape[dim] = 1
            mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
            mask = jnp.broadcast_to(mask, inputs.shape)
            return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
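A stripped-down version of the masking step (a sketch; `toy_dropout` is a made-up name, not part of the module above) showing how `lax.select` applies the Bernoulli keep mask and rescales the survivors:

import jax.numpy as jnp
from jax import lax, random

def toy_dropout(rng, x, rate=0.5):
    keep = random.bernoulli(rng, p=1.0 - rate, shape=x.shape)
    # Keep with probability 1 - rate, rescale so the expected value is preserved.
    return lax.select(keep, x / (1.0 - rate), jnp.zeros_like(x))

print(toy_dropout(random.PRNGKey(0), jnp.ones((2, 3))))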
Example #10
def dropout(inputs, rate, deterministic=False, rng=None):
    """DEPRECATION WARNING:
  The `flax.nn` module is Deprecated, use `flax.linen` instead. 
  Learn more and find an upgrade guide at 
  https://github.com/google/flax/blob/master/flax/linen/README.md
  Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    rate: the probability of masking out a value.
    deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
      masked, whereas if true, no mask is applied and the inputs are returned as
      is.
    rng: an optional `jax.random.PRNGKey`. By default `nn.make_rng()` will
      be used.
  Returns:
    The masked inputs.
  """
    if rate == 0.:
        return inputs
    keep_prob = 1. - rate

    if deterministic:
        return inputs
    else:
        if rng is None:
            rng = make_rng()
        mask = random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
Example #11
def assert_discharge_rule(error, enabled_errors, pred, code, payload, *, msgs):
    if ErrorCategory.USER_CHECK not in enabled_errors:
        return [], error

    out_err = error.err | jnp.logical_not(pred)
    out_code = lax.select(error.err, error.code, code)
    return [], Error(out_err, out_code, {**error.msgs, **msgs}, payload)
Example #12
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
Example #13
def dropout(inputs, rate, deterministic=False, rng=None):
    """Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    rate: the probability of masking out a value.
    deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
      masked, whereas if true, no mask is applied and the inputs are returned as
      is.
    rng: an optional `jax.random.PRNGKey`. By default `nn.make_rng()` will
      be used.
  Returns:
    The masked inputs.
  """
    if rate == 0.:
        return inputs
    keep_prob = 1. - rate

    if deterministic:
        return inputs
    else:
        if rng is None:
            rng = make_rng()
        mask = random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
Example #14
def where(condition, x=None, y=None):
    if x is None or y is None:
        raise ValueError("Must use the three-argument form of where().")
    if not onp.issubdtype(_dtype(condition), onp.bool_):
        condition = lax.ne(condition, zeros_like(condition))
    condition, x, y = broadcast_arrays(condition, x, y)
    return lax.select(condition, *_promote_dtypes(x, y))
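Usage of the three-argument form (a minimal sketch): a non-boolean condition is first compared against zero, as the `lax.ne` branch above does.

import jax.numpy as jnp

x = jnp.array([1.0, -2.0, 0.0])
print(jnp.where(x, x, 99.0))      # [ 1., -2., 99.]  (0.0 is falsy)
print(jnp.where(x > 0, x, 99.0))  # [ 1., 99., 99.]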
Example #15
def forward_func(z, x, v, δ):
    # Clip the first two state components below at -500 (in the original domain
    # exp(-500) ≈ 7 × 10^(-218)) when updating the state, to prevent NaN issues
    # when these components tend to negative infinity. 500 was chosen as the
    # cutoff to avoid underflow / overflow: in double precision exp(-500) is
    # non-zero and exp(500) is finite, whereas exp(-1000) = 0 and exp(1000) = inf.
    # We clip both before and after _forward_func to avoid NaN gradients:
    # https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf
    x = x.at[:2].set(jnp.clip(x[:2], -500))
    x_ = _forward_func(z, x, v, δ)
    return jnp.array(
        [
            lax.select(x[0] > -500, x_[0], x[0]),
            lax.select(x[1] > -500, x_[1], x[1]),
            x_[2],
        ]
    )
Example #16
def remainder(x1, x2):
    x1, x2 = _promote_args("remainder", x1, x2)
    zero = _constant_like(x1, 0)
    trunc_mod = lax.rem(x1, x2)
    trunc_mod_not_zero = lax.ne(trunc_mod, zero)
    do_plus = lax.bitwise_and(
        lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
    return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
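A quick contrast with the truncated primitive (a sketch): `lax.rem` truncates toward zero, while the corrected `jnp.remainder` follows NumPy's convention of matching the divisor's sign.

import jax.numpy as jnp
from jax import lax

print(lax.rem(-7.0, 3.0))        # -1.0, truncated toward zero
print(jnp.remainder(-7.0, 3.0))  #  2.0, sign follows the divisor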
Example #17
def _abs_taylor_rule(x, series_in, **params):
  x, = x
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
  fix_sign = lambda y: negs * y
  series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
  return primal_out, series_out
Example #18
def cdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    diff = lax.div(lax.sub(x, loc), scale)
    return lax.select(lax.le(diff, zero), lax.mul(half, lax.exp(diff)),
                      lax.sub(one, lax.mul(half, lax.exp(lax.neg(diff)))))
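Assuming this is the implementation behind `jax.scipy.stats.laplace.cdf`, the two branches can be checked through the public wrapper (a minimal sketch):

from jax.scipy.stats import laplace

print(laplace.cdf(0.0))                      # 0.5 at the location parameter
print(laplace.cdf(1.0), laplace.cdf(-1.0))   # symmetric branches; the two sum to 1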
Example #19
 def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory):
     rng = rng_factory(self.rng())
     op = lambda c, x, y: lax.select(c < 0, x, y)
     self._CheckBatching(op, 5, bdims, (
         pred_shape,
         arg_shape,
         arg_shape,
     ), (onp.bool_, arg_dtype, arg_dtype), rng)
Example #20
def _wrap_between(x, _a):
    """Wraps `x` between `[-a, a]`."""
    a = _constant_like(x, _a)
    two_a = _constant_like(x, 2 * _a)
    zero = _constant_like(x, 0)
    rem = lax.rem(lax.add(x, a), two_a)
    rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
    return lax.sub(rem, a)
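An equivalent standalone formulation for real inputs (a sketch using public `jnp` ops; `wrap_between` is a made-up name): `jnp.mod` already returns a result with the divisor's sign, so it folds the `lax.select` correction in.

import jax.numpy as jnp

def wrap_between(x, a):
    # Map x into the interval [-a, a).
    return jnp.mod(x + a, 2 * a) - a

print(wrap_between(jnp.array([3.5, -3.5]), jnp.pi))  # approx [-2.783, 2.783]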
Example #21
 def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
     rng = jtu.rand_default(self.rng())
     op = lambda c, x, y: lax.select(c < 0, x, y)
     self._CheckBatching(op, 5, bdims, (
         pred_shape,
         arg_shape,
         arg_shape,
     ), (np.bool_, arg_dtype, arg_dtype), rng)
Example #22
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        causal_attention_mask = None
        if self.causal:
            query_length, key_length = query.shape[1], key.shape[1]
            causal_attention_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]

        if attention_mask is not None and causal_attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4")
        elif causal_attention_mask is not None:
            attention_mask = causal_attention_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
Example #23
def _logaddexp(x1, x2):
  """
  Logaddexp while ignoring the custom_jvp rule.
  """
  amax = lax.max(x1, x2)
  delta = lax.sub(x1, x2)
  return lax.select(jnp.isnan(delta),
                    lax.add(x1, x2),  # NaNs or infinities of the same sign.
                    lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
Example #24
def onehot(labels,
           num_classes,
           on_value=1.0,
           off_value=0.0,
           dtype=jnp.float32):
    x = (labels[..., None] == jnp.arange(num_classes)[None])
    x = lax.select(x, jnp.full(x.shape, on_value),
                   jnp.full(x.shape, off_value))
    return x.astype(dtype)
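For the default on/off values, the public `jax.nn.one_hot` produces the same encoding as the helper above (a sketch for comparison):

import jax.numpy as jnp
from jax import nn

print(nn.one_hot(jnp.array([0, 2, 1]), num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]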
Example #25
 def fn(x1, x2):
   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
   # Comparisons on complex types are defined as a lexicographic ordering on
   # the (real, imag) pair.
   if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
     rx = lax.real(x1)
     ry = lax.real(x2)
     return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)),
                       lax_fn(rx, ry))
   return lax_fn(x1, x2)
Example #26
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        # get query proj
        query_states = self.q_proj(hidden_states)

        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
Example #27
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    return lax.select(lax_internal._isnan(delta),
                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
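A quick check of the two paths on the real branch (a sketch): the max-shift keeps large inputs finite, and the `isnan(delta)` select handles equal infinities.

import jax.numpy as jnp

print(jnp.logaddexp(1000.0, 1000.0))      # ~1000.6931, no overflow
print(jnp.logaddexp(-jnp.inf, -jnp.inf))  # -inf via the same-sign-infinity branch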
Example #28
    def __call__(self, x, deterministic=False, rng=None):
        if self.rate == 0.:
            return x
        keep_prob = 1. - self.rate

        if deterministic:
            return x
        else:
            if rng is None:
                rng = self.scope.make_rng('dropout')
            mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
            return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
Example #29
    def __call__(self,
                 hidden_states,
                 attention_mask,
                 deterministic=True,
                 output_attentions: bool = False):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(
            hidden_states).reshape(hidden_states.shape[:2] +
                                   (self.config.num_attention_heads, head_dim))
        value_states = self.value(
            hidden_states).reshape(hidden_states.shape[:2] +
                                   (self.config.num_attention_heads, head_dim))
        key_states = self.key(
            hidden_states).reshape(hidden_states.shape[:2] +
                                   (self.config.num_attention_heads, head_dim))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_output = dot_product_attention(
            query_states,
            key_states,
            value_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        outputs = (attn_output.reshape(attn_output.shape[:2] + (-1, )), )

        # TODO: at the moment it's not possible to retrieve attn_weights from
        # dot_product_attention, but should be in the future -> add functionality then

        return outputs
Example #30
def _lax_min_taylor_rule(primal_in, series_in):
    x, y = primal_in
    xgy = x < y   # less than mask
    xey = x == y  # equal to mask
    primal_out = lax.select(xgy, x, y)

    def select_min_and_avg_eq(x_i, y_i):
        """Select x where x>y or average when x==y"""
        min_i = lax.select(xgy, x_i, y_i)
        min_i = lax.select(xey, (x_i + y_i)/2, min_i)
        return min_i

    series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
    return primal_out, series_out
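The rule can be exercised through `jax.experimental.jet` (a sketch with made-up inputs; printed output abbreviated): for x < y the series coefficients of x are selected.

from jax import lax
from jax.experimental import jet

primal_out, series_out = jet.jet(lax.min, (1.0, 2.0), ([1.0], [0.5]))
print(primal_out)   # 1.0   (the minimum)
print(series_out)   # [1.0] (x < y, so x's first-order coefficient is taken)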