Example #1
def jvpfun(instantiate, transform_stack, primals, tangents):
    tangents = [
        Zero.from_value(t)
        if not isinstance(t, Zero) and dtype(t) is float0 else t
        for t in tangents
    ]
    ctx = (source_info_util.transform_name_stack('jvp')
           if transform_stack else contextlib.nullcontext())
    with core.new_main(JVPTrace) as main, ctx:
        out_primals, out_tangents = yield (main, primals, tangents), {}
        del main
    if type(instantiate) is bool:
        instantiate = [instantiate] * len(out_tangents)
    out_tangents = [
        instantiate_zeros(t) if inst else t
        for t, inst in zip(out_tangents, instantiate)
    ]
    yield out_primals, out_tangents
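
This generator underlies jax.jvp, which pairs each primal input with a tangent. A minimal usage sketch of the public API (illustrative, not part of the source):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

# Returns the primal output and the directional derivative along the tangent.
primal_out, tangent_out = jax.jvp(f, (2.0,), (1.0,))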
Example #2
def _power(x1, x2):
  x1, x2 = _promote_args("power", x1, x2)
  dtype = dtypes.dtype(x1)
  if not dtypes.issubdtype(dtype, np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.

  # TODO(phawkins): add integer pow support to XLA.
  bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
  acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
  for _ in range(bits):
    acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
    x1 = lax.mul(x1, x1)
    x2 = lax.shift_right_logical(x2, one)
  return acc
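
The loop above is square-and-multiply (binary) exponentiation over at most 6 exponent bits. A standalone pure-Python sketch of the same idea, for illustration only:

def int_pow(base, exp):
    # Square-and-multiply: consume one exponent bit per iteration.
    acc = 1
    while exp:
        if exp & 1:
            acc *= base
        base *= base
        exp >>= 1
    return acc

assert int_pow(3, 5) == 243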
Example #3
def _ravel_list(lst):
    if not lst: return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = [dtypes.dtype(l) for l in lst]
    to_dtype = dtypes.result_type(*from_dtypes)
    sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
    indices = np.cumsum(sizes)

    def unravel(arr):
        chunks = jnp.split(arr, indices[:-1])
        with warnings.catch_warnings():
            warnings.simplefilter(
                "ignore")  # ignore complex-to-real cast warning
            return [
                lax.convert_element_type(chunk.reshape(shape), dtype)
                for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
            ]

    ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
    raveled = jnp.concatenate([ravel(e) for e in lst])
    return raveled, unravel
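
_ravel_list backs the public jax.flatten_util.ravel_pytree. A hedged usage sketch of that API:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unflatten = ravel_pytree(params)   # flat: 1-D array of length 9
restored = unflatten(flat)               # same structure and shapes as params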
Example #4
def frexp(x):
    _check_arraylike("frexp", x)
    x, = _promote_dtypes_inexact(x)
    if dtypes.issubdtype(x.dtype, np.complexfloating):
        raise TypeError("frexp does not support complex-valued inputs")

    dtype = dtypes.dtype(x)
    info = dtypes.finfo(dtype)
    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1

    x1, x2 = _normalize_float(x)
    x2 += ((x1 >> info.nmant) & mask) - bias + 1
    x1 &= ~(mask << info.nmant)
    x1 |= (bias - 1) << info.nmant
    x1 = lax.bitcast_convert_type(x1, dtype)

    cond = isinf(x) | isnan(x) | (x == 0)
    x2 = _where(cond, lax_internal._zeros(x2), x2)
    return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
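
A small usage note (illustrative): frexp decomposes x into a mantissa in [0.5, 1) (or 0 for zero inputs) and an integer exponent such that mantissa * 2**exponent reconstructs x.

import jax.numpy as jnp

m, e = jnp.frexp(jnp.array([8.0, 5.5, 0.0]))
# m == [0.5, 0.6875, 0.0], e == [4, 3, 0]; m * 2.0**e gives back the input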
Example #5
def _ravel_list(lst):
    if not lst: return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = [dtypes.dtype(l) for l in lst]
    to_dtype = dtypes.result_type(*from_dtypes)
    sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
    indices = np.cumsum(sizes)

    if all(dt == to_dtype for dt in from_dtypes):
        # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
        # See https://github.com/google/jax/issues/7809.
        del from_dtypes, to_dtype

        def unravel(arr):
            chunks = jnp.split(arr, indices[:-1])
            return [
                chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)
            ]

        raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
        return raveled, unravel

    # When there is more than one distinct input dtype, we perform type
    # conversions and produce a dtype-specific unravel function.
    def unravel(arr):
        arr_dtype = dtypes.dtype(arr)
        if arr_dtype != to_dtype:
            raise TypeError(
                f"unravel function given array of dtype {arr_dtype}, "
                f"but expected dtype {to_dtype}")
        chunks = jnp.split(arr, indices[:-1])
        with warnings.catch_warnings():
            warnings.simplefilter(
                "ignore")  # ignore complex-to-real cast warning
            return [
                lax.convert_element_type(chunk.reshape(shape), dtype)
                for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
            ]

    ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
    raveled = jnp.concatenate([ravel(e) for e in lst])
    return raveled, unravel
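
With mixed input dtypes, the returned unravel is pinned to the promoted dtype; a small sketch of that error path, calling the function above directly (illustrative only):

raveled, unravel = _ravel_list([jnp.zeros(2, jnp.int32), jnp.ones(2, jnp.float32)])
unravel(raveled)                      # fine: raveled has the promoted float32 dtype
unravel(raveled.astype(jnp.int32))    # raises TypeError via the dtype check above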
Example #6
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    return lax.convert_element_type(out, dtype)
  else:
    return _float_divmod(x1, x2)[0]
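
For integer inputs, lax.div truncates toward zero, so the select above subtracts 1 whenever the operands' signs differ and the division is inexact. A quick sketch of the resulting floor semantics (illustrative):

import jax.numpy as jnp

jnp.floor_divide(7, 2)    # 3
jnp.floor_divide(-7, 2)   # -4 (truncation would give -3; the correction floors it)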
Example #7
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = nanmean(a,
                     axis,
                     dtype=computation_dtype,
                     keepdims=True,
                     where=where)

    centered = _where(lax_internal._isnan(a), 0,
                      lax.sub(a, a_mean))  # double-where trick for gradients.
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                     axis=axis,
                     keepdims=keepdims,
                     where=where)
    normalizer = normalizer - ddof
    normalizer_mask = lax.le(normalizer, 0)
    result = sum(centered, axis, keepdims=keepdims, where=where)
    result = _where(normalizer_mask, np.nan, result)
    divisor = _where(normalizer_mask, 1, normalizer)
    out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
    return lax.convert_element_type(out, dtype)
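
A brief usage sketch (illustrative): NaN entries are zeroed out of the sum and excluded from the normalizer, and a non-positive normalizer yields NaN.

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, jnp.nan, 4.0])
jnp.nanvar(x)           # variance of [1, 2, 4]; the NaN is ignored
jnp.nanvar(x, ddof=3)   # normalizer 3 - 3 <= 0, so the mask above produces NaN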
Example #8
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
    return lax.convert_element_type(out, dtype)
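
A short usage sketch (illustrative), including the where mask that replaces the plain element count as the normalizer:

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
jnp.var(x, axis=0)        # per-column variance
jnp.var(x, where=x > 0)   # normalizer counts only the masked-in entries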
Example #9
def _closure_convert_for_avals(fun, in_tree, in_avals):
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
    out_tree = out_tree()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(frostig,mattjj): revise this approach
    from jax.numpy import inexact
    is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), inexact)
    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_consts = len(hoisted_consts)

    def converted_fun(*args_hconsts):
        num_args = len(args_hconsts) - num_consts
        args, hoisted_consts = split_list(args_hconsts, [num_args])
        consts = merge(closure_consts, hoisted_consts)
        all_args, in_tree2 = tree_flatten(tuple(args))
        assert in_tree == in_tree2
        out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, hoisted_consts
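
This appears to be the helper behind jax.closure_convert, which rewrites a closure so that its float-typed captured constants become explicit arguments. A hedged usage sketch of the public API:

import jax
import jax.numpy as jnp

c = jnp.array(2.0)
f = lambda x: c * x                       # closes over a constant
converted, consts = jax.closure_convert(f, 3.0)
converted(3.0, *consts)                   # same value as f(3.0), with c passed explicitly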
Example #10
def _nan_reduction(a,
                   name,
                   jnp_reduction,
                   init_val,
                   nan_if_all_nan,
                   axis=None,
                   keepdims=None,
                   **kwargs):
    _check_arraylike(name, a)
    if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
        return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

    out = jnp_reduction(_where(lax_internal._isnan(a),
                               _reduction_init_val(a, init_val), a),
                        axis=axis,
                        keepdims=keepdims,
                        **kwargs)
    if nan_if_all_nan:
        return _where(
            all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
            _lax_const(a, np.nan), out)
    else:
        return out
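
A quick usage sketch (illustrative) of the reductions built on this helper:

import jax.numpy as jnp

x = jnp.array([jnp.nan, 2.0, 5.0])
jnp.nanmax(x)                       # 5.0: NaNs are replaced by the init value before reducing
jnp.nanmax(jnp.full(3, jnp.nan))    # NaN, via the nan_if_all_nan branch above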
Example #11
def signbit(x):
    x, = _promote_args("signbit", x)
    dtype = dtypes.dtype(x)
    if dtypes.issubdtype(dtype, np.integer):
        return lax.lt(x, _constant_like(x, 0))
    elif dtypes.issubdtype(dtype, np.bool_):
        return lax.full_like(x, False, dtype=np.bool_)
    elif not dtypes.issubdtype(dtype, np.floating):
        raise ValueError("jax.numpy.signbit is not well defined for %s" %
                         dtype)

    # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
    # F32.
    if dtype == dtypes.bfloat16:
        dtype = np.float32
        x = lax.convert_element_type(x, np.float32)

    info = dtypes.finfo(dtype)
    if info.bits not in _INT_DTYPES:
        raise NotImplementedError(
            "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
    int_type = _INT_DTYPES[info.bits]
    x = lax.bitcast_convert_type(x, int_type)
    return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
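
A short usage sketch (illustrative): because the sign bit is read off a bitcast of the float representation, negative zero and -inf report True.

import jax.numpy as jnp

jnp.signbit(jnp.array([-2.0, 0.0, -0.0, jnp.inf, -jnp.inf]))
# [ True, False,  True, False,  True]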
Example #12
 def testDtypeFromValue(self, dtype):
     self.assertEqual(dtypes.dtype(dtype.type(0)), dtype)
Example #13
 def testDtypeFromScalarValue(self, typ):
     self.assertEqual(dtypes.dtype(typ(0)),
                      dtypes.python_scalar_dtypes[typ])
Example #14
def _result_dtype(op, *args):
  """Compute result dtype of applying op to arguments with given dtypes."""
  args = [np.ones((0,) * np.ndim(arg), dtypes.dtype(arg)) for arg in args]
  return dtypes.dtype(op(*args))
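
The same zero-size-array trick works outside JAX; a minimal NumPy illustration (not part of the source):

import numpy as np

a = np.ones((0,), np.int32)
b = np.ones((0,), np.float32)
np.add(a, b).dtype   # float64 under NumPy's promotion rules, computed without real work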
Example #15
def _reduction(a,
               name,
               np_fun,
               op,
               init_val,
               has_identity=True,
               preproc=None,
               bool_op=None,
               upcast_f16_for_computation=False,
               axis=None,
               dtype=None,
               out=None,
               keepdims=False,
               initial=None,
               where_=None,
               parallel_reduce=None):
    bool_op = bool_op or op
    # Note: we must accept out=None as an argument, because numpy reductions delegate to
    # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
    # exists, passing along all its arguments.
    if out is not None:
        raise NotImplementedError(
            f"The 'out' argument to jnp.{name} is not supported.")
    _check_arraylike(name, a)
    lax_internal._check_user_dtype_supported(dtype, name)
    axis = core.concrete_or_error(None, axis,
                                  f"axis argument to jnp.{name}().")

    if initial is None and not has_identity and where_ is not None:
        raise ValueError(
            f"reduction operation {name} does not have an identity, so to use a "
            f"where mask one has to specify 'initial'")

    a = a if isinstance(a, ndarray) else _asarray(a)
    a = preproc(a) if preproc else a
    pos_dims, dims = _reduction_dims(a, axis)

    if initial is None and not has_identity:
        shape = np.shape(a)
        if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
            raise ValueError(
                f"zero-size array to reduction operation {name} which has no identity"
            )

    result_dtype = dtypes.canonicalize_dtype(
        dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
    if upcast_f16_for_computation and dtypes.issubdtype(
            result_dtype, np.inexact):
        computation_dtype = _upcast_f16(result_dtype)
    else:
        computation_dtype = result_dtype
    a = lax.convert_element_type(a, computation_dtype)
    op = op if computation_dtype != np.bool_ else bool_op
    # NB: in XLA, init_val must be an identity for the op, so the user-specified
    # initial value must be applied afterward.
    init_val = _reduction_init_val(a, init_val)
    if where_ is not None:
        a = _where(where_, a, init_val)
    if pos_dims is not dims:
        if parallel_reduce is None:
            raise NotImplementedError(
                f"Named reductions not implemented for jnp.{name}()")
        result = parallel_reduce(a, dims)
    else:
        result = lax.reduce(a, init_val, op, dims)
    if initial is not None:
        result = op(lax.convert_element_type(initial, a.dtype), result)
    if keepdims:
        result = lax.expand_dims(result, pos_dims)
    return lax.convert_element_type(result, dtype or result_dtype)
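
A brief usage sketch (illustrative) of the keyword handling above, using jnp.sum and jnp.max as public entry points:

import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])
jnp.sum(x, axis=0, keepdims=True)     # shape (1, 2)
jnp.max(x, where=x < 4, initial=0)    # 3; max has no identity, so 'where' requires 'initial'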
Example #16
def divmod(x1, x2):
    x1, x2 = _promote_args("divmod", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
        return floor_divide(x1, x2), remainder(x1, x2)
    else:
        return _float_divmod(x1, x2)
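
A quick sketch of the floor semantics this delegates to (illustrative):

import jax.numpy as jnp

jnp.divmod(7, 3)    # (2, 1)
jnp.divmod(-7, 3)   # (-3, 2): the quotient floors and the remainder takes the divisor's sign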
Example #17
 def testDtypeFromNone(self):
     with self.assertRaisesRegex(ValueError, "Invalid argument to dtype"):
         dtypes.dtype(None)
Example #18
File: ad.py  Project: jbampton/jax
def replace_float0s(primal, tangent):
  if dtype(tangent) is float0:
    return zeros_like_jaxval(primal)
  else:
    return tangent
Example #19
 def dtype(self):
   return dtypes.dtype(self.val)
Example #20
 def op(*args):
     zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
     args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(
         x, zero(x)) for x in args)
     return bitwise_op(*_promote_args(np_op.__name__, *args))
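
This looks like the factory behind the jnp.logical_* functions: non-boolean arguments are first compared against zero, then the underlying bitwise op is applied. A hedged usage sketch:

import jax.numpy as jnp

jnp.logical_and(jnp.array([0, 2, 3]), jnp.array([1, 0, 5]))
# [False, False,  True]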
Example #21
def absolute(x):
    _check_arraylike('absolute', x)
    dt = dtypes.dtype(x)
    return x if dt == np.bool_ or dtypes.issubdtype(
        dt, np.unsignedinteger) else lax.abs(x)
Example #22
def _constant_like(x, const):
    return np.array(const, dtype=dtypes.dtype(x))
Example #23
 def testDtypeFromDtype(self, dtype):
     self.assertEqual(dtypes.dtype(dtype), dtype)
Example #24
File: ad.py  Project: jbampton/jax
def recast_to_float0(primal, tangent):
  if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
    return Zero(get_aval(primal).at_least_vspace())
  else:
    return tangent
Example #25
 def testDtypeFromString(self, dtype):
     self.assertEqual(dtypes.dtype(str(dtype)), dtype)
Example #26
def _canonicalize_arg(arg):
  return np.asarray(arg, dtype=dtypes.dtype(arg, canonicalize=True), order='C')
Example #27
def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
    return (ir.RankedTensorType.get((),
                                    dtype_to_ir_type(dtypes.dtype('int32'))), )
Example #28
def copysign(x1, x2):
    x1, x2 = _promote_args_inexact("copysign", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
        raise TypeError("copysign does not support complex-valued inputs")
    return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
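
A short usage sketch (illustrative): the result carries the magnitude of x1 and the sign bit of x2.

import jax.numpy as jnp

jnp.copysign(jnp.array([3.0, -3.0]), -1.0)   # [-3., -3.]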