Example #1
def array(object, dtype=None, copy=True, order="K", ndmin=0):
    del copy  # Unused.
    if ndmin != 0 or order != "K":
        raise NotImplementedError("Only implemented for order='K', ndmin=0.")

    if hasattr(object, '__asarray__'):
        return object.__asarray__(dtype)
    elif isinstance(object, ndarray):
        if dtype and _dtype(object) != dtype:
            return lax.convert_element_type(object, dtype)
        else:
            return object
    elif isinstance(object, (list, tuple)):
        if object:
            subarrays = [
                expand_dims(array(elt, dtype=dtype), 0) for elt in object
            ]
            return concatenate(subarrays)
        else:
            return onp.array([], dtype)
    elif isscalar(object):
        out = lax.reshape(object, ())
        if dtype and _dtype(out) != dtype:
            return lax.convert_element_type(out, dtype)
        else:
            return out
    else:
        raise TypeError("Unexpected input type for array: {}".format(
            type(object)))
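A short usage sketch of the nested-list branch above (assuming this is the backing implementation of the public jnp.array):

import jax.numpy as jnp

# Nested lists are converted element by element and concatenated along a new
# leading axis, so each nesting level contributes one dimension.
x = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
print(x.shape, x.dtype)  # (2, 2) float32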
Example #2
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

    centered = _where(lax_internal._isnan(a), 0,
                      a - a_mean)  # double-where trick for gradients.
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                     axis=axis,
                     keepdims=keepdims,
                     where=where)
    normalizer = normalizer - ddof
    normalizer_mask = lax.le(normalizer, 0)
    result = sum(centered, axis, keepdims=keepdims, where=where)
    result = _where(normalizer_mask, np.nan, result)
    divisor = _where(normalizer_mask, 1, normalizer)
    out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
    return lax.convert_element_type(out, dtype)
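For intuition, a minimal check of the NaN handling above via the public jnp.nanvar wrapper (assuming it dispatches to this implementation):

import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])
# The NaN entry is excluded from both the sum and the normalizer, so with
# ddof=0 this is the variance of [1., 3.].
print(jnp.nanvar(x))  # 1.0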
Example #3
def clip(a, a_min=None, a_max=None):
    a_min = _dtype_info(_dtype(a)).min if a_min is None else a_min
    a_max = _dtype_info(_dtype(a)).max if a_max is None else a_max
    if _dtype(a_min) != _dtype(a):
        a_min = lax.convert_element_type(a_min, _dtype(a))
    if _dtype(a_max) != _dtype(a):
        a_max = lax.convert_element_type(a_max, _dtype(a))
    return lax.clamp(a_min, a, a_max)
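A small usage sketch of the clamping behavior through the public jnp.clip wrapper:

import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])
# a_min/a_max are converted to x's dtype before lax.clamp is applied.
print(jnp.clip(x, 0.0, 1.0))  # [0.  0.5 1. ]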
Example #4
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     dtype = get_dtypes(self.probs)[0]
     value = lax.convert_element_type(value, dtype)
     total_count = lax.convert_element_type(self.total_count, dtype)
     return gammaln(total_count + 1) + np.sum(
         xlogy(value, self.probs) - gammaln(value + 1), axis=-1)
Example #5
def transform(T, v):
    T_dtype, v_dtype = _dtype(T), _dtype(v)
    if T_dtype != v_dtype:
        higher_dtype = lax.dtypes.promote_types(T_dtype, v_dtype)
        if higher_dtype == v_dtype:
            T = lax.convert_element_type(T, v_dtype)
        else:
            v = lax.convert_element_type(v, T_dtype)
    return transform_p.bind(T, v)
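The branch above converts whichever operand has the lower dtype up to the promoted type; a quick illustration of the promotion rule using the jnp-level analogue of lax.dtypes.promote_types:

import jax.numpy as jnp

# int32 and float32 promote to float32, so in transform() the integer operand
# would be the one converted.
print(jnp.promote_types(jnp.int32, jnp.float32))  # float32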
Example #6
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
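The 64-bit branch packs two uint32 words into a single 64-bit word via a shift and bitwise OR; a plain-NumPy sketch of that packing:

import numpy as np

hi, lo = np.uint64(0xDEADBEEF), np.uint64(0x12345678)
# Mirrors lax.shift_left(bits[0], dtype(32)) | bits[1] in the code above.
print(hex((hi << np.uint64(32)) | lo))  # 0xdeadbeef12345678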
Example #7
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     dtype = get_dtypes(self.logits)[0]
     value = lax.convert_element_type(value, dtype)
     total_count = lax.convert_element_type(self.total_count, dtype)
     normalize_term = total_count * logsumexp(
         self.logits, axis=-1) - gammaln(total_count + 1)
     return np.sum(value * self.logits - gammaln(value + 1),
                   axis=-1) - normalize_term
Example #8
def _poisson(key, rate, shape, dtype):
    # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
    shape = shape or np.shape(rate)
    rate = lax.convert_element_type(rate, canonicalize_dtype(np.float64))
    rate = np.broadcast_to(rate, shape)
    rng_keys = random.split(key, np.size(rate))
    if xla_bridge.get_backend().platform == 'cpu':
        k = lax.map(_poisson_one, (rng_keys, np.reshape(rate, -1)))
    else:
        k = vmap(_poisson_one)((rng_keys, np.reshape(rate, -1)))
    k = lax.convert_element_type(k, dtype)
    return np.reshape(k, shape)
Example #9
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     dtype = get_dtypes(self.probs)[0]
     value = lax.convert_element_type(value, dtype)
     total_count = lax.convert_element_type(self.total_count, dtype)
     log_factorial_n = gammaln(total_count + 1)
     log_factorial_k = gammaln(value + 1)
     log_factorial_nmk = gammaln(total_count - value + 1)
     return (log_factorial_n - log_factorial_k - log_factorial_nmk +
             xlogy(value, self.probs) +
             xlog1py(total_count - value, -self.probs))
Example #10
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     dtype = get_dtypes(self.logits)[0]
     value = lax.convert_element_type(value, dtype)
     total_count = lax.convert_element_type(self.total_count, dtype)
     log_factorial_n = gammaln(total_count + 1)
     log_factorial_k = gammaln(value + 1)
     log_factorial_nmk = gammaln(total_count - value + 1)
     normalize_term = (total_count * np.clip(self.logits, 0) +
                       xlog1py(total_count, np.exp(-np.abs(self.logits))) -
                       log_factorial_n)
     return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
Example #11
File: scatter.py  Project: sharadmv/jax
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, normalize_indices):
    dtype = lax.dtype(x)
    x, y = jnp._promote_dtypes(x, y)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indices_are_sorted,
                     unique_indices=unique_indices)
    return lax.convert_element_type(out, dtype)
Example #12
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    if len(shape) != image.ndim:
        msg = (
            'shape must have length equal to the number of dimensions of x; '
            f' {shape} vs {image.shape}')
        raise ValueError(msg)
    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    if not jnp.issubdtype(image.dtype, jnp.inexact):
        image = lax.convert_element_type(image,
                                         jnp.result_type(image, jnp.float32))
    # Skip dimensions that have scale=1 and translation=0; this is only possible
    # because all of the current resize methods (kernels) are interpolating, so
    # output = input under an identity warp.
    spatial_dims = tuple(
        i for i in range(len(shape))
        if not core.symbolic_equal_dim(image.shape[i], shape[i]))
    scale = [
        1.0 if core.symbolic_equal_dim(
            shape[d], 0) else core.dimension_as_value(shape[d]) /
        core.dimension_as_value(image.shape[d]) for d in spatial_dims
    ]
    return _scale_and_translate(image, shape, spatial_dims, scale,
                                [0.] * len(spatial_dims), kernel, antialias,
                                precision)
Example #13
def _ravel_list(lst):
  if not lst: return jnp.array([], jnp.float32), lambda _: []
  from_dtypes = [dtypes.dtype(l) for l in lst]
  to_dtype = dtypes.result_type(*from_dtypes)
  sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
  indices = np.cumsum(sizes)

  if all(dt == to_dtype for dt in from_dtypes):
    # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
    # See https://github.com/google/jax/issues/7809.
    del from_dtypes, to_dtype
    def unravel(arr):
      chunks = jnp.split(arr, indices[:-1])
      return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
    raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
    return raveled, unravel

  # When there is more than one distinct input dtype, we perform type
  # conversions and produce a dtype-specific unravel function.
  def unravel(arr):
    arr_dtype = dtypes.dtype(arr)
    if arr_dtype != to_dtype:
      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                      f"but expected dtype {to_dtype}")
    chunks = jnp.split(arr, indices[:-1])
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      return [lax.convert_element_type(chunk.reshape(shape), dtype)
              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

  ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
  raveled = jnp.concatenate([ravel(e) for e in lst])
  return raveled, unravel
Example #14
File: fft.py  Project: qqsun8819/jax
def _promote_to_real(arg):
    dtype = dtypes.result_type(arg, np.float32)
    # XLA's FFT op only supports F32.
    # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
    if lib.version <= (0, 1, 47) and dtype == np.float64:
        dtype = np.float32
    return lax.convert_element_type(arg, dtype)
Example #15
  def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
    if out is not None:
      raise ValueError("reduction does not support `out` argument.")

    a = a if isinstance(a, ndarray) else asarray(a)
    dims = _reduction_dims(a, axis)
    result_dtype = _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
    if _dtype(a) != result_dtype:
      a = lax.convert_element_type(a, result_dtype)
    result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)
    if keepdims:
      shape_with_singletons = lax.subvals(shape(a), zip(dims, (1,) * len(dims)))
      result = lax.reshape(result, shape_with_singletons)
    if dtype and onp.dtype(dtype) != onp.dtype(result_dtype):
      result = lax.convert_element_type(result, dtype)
    return result
Example #16
File: fft.py  Project: qqsun8819/jax
def _promote_to_complex(arg):
    dtype = dtypes.result_type(arg, np.complex64)
    # XLA's FFT op only supports C64 in jaxlib versions 0.1.47 and earlier.
    # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer.
    if lib.version <= (0, 1, 47) and dtype == np.complex128:
        dtype = np.complex64
    return lax.convert_element_type(arg, dtype)
Example #17
def nanmean(a,
            axis: Optional[Union[int, Tuple[int, ...]]] = None,
            dtype=None,
            out=None,
            keepdims=False,
            where=None):
    _check_arraylike("nanmean", a)
    lax_internal._check_user_dtype_supported(dtype, "nanmean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanmean is not supported.")
    if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(
            dtypes.dtype(a), np.integer):
        return mean(a, axis, dtype, out, keepdims, where=where)
    if dtype is None:
        dtype = dtypes.dtype(a)
    nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
    normalizer = sum(nan_mask,
                     axis=axis,
                     dtype=np.int32,
                     keepdims=keepdims,
                     where=where)
    normalizer = lax.convert_element_type(normalizer, dtype)
    td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
                 normalizer)
    return td
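A quick check of the NaN-aware normalizer via the public jnp.nanmean wrapper:

import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])
# Only non-NaN entries are counted, so this is (1 + 3) / 2.
print(jnp.nanmean(x))  # 2.0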
Example #18
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        _check_arraylike(np_reduction.__name__, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{np_reduction.__name__} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

        if axis is None or _isscalar(a):
            a = lax.reshape(a, (np.size(a), ))
            axis = 0

        a_shape = list(np.shape(a))
        num_dims = len(a_shape)
        axis = _canonicalize_axis(axis, num_dims)

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        if not dtype and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
Example #19
    def init(self, rng_key, *args, **kwargs):
        """
        Gets the initial SVI state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        model_trace = trace(replay(model_init, guide_trace)).get_trace(
            *args, **kwargs, **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params = tree_map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)), params)
        return SVIState(self.optim.init(params), rng_key)
Example #20
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)

    if dtype is None:
        dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
    dtype = dtypes.canonicalize_dtype(dtype)

    return lax.div(sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
                   lax.convert_element_type(normalizer, dtype))
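The `where` argument changes the normalizer as well as the sum; a small sketch via the public jnp.mean wrapper:

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 100.0])
mask = jnp.array([True, True, False])
# Masked-out entries are excluded from both the numerator and the normalizer.
print(jnp.mean(x, where=mask))  # 1.5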
Example #21
    def init(rng, shape):
        # Check the shape
        std = lax.convert_element_type(stddev, dtype)
        if len(shape) < 2:
            raise ValueError('The array to initialize must be '
                             'at least two-dimensional')
        # Flatten the input shape with the last dimension remaining
        # its original shape so it works for conv2d
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        flat_shape = (num_cols,
                      num_rows) if num_rows < num_cols else (num_rows,
                                                             num_cols)

        # Generate a random matrix
        a = random.normal(rng, flat_shape, dtype=dtype)
        # Compute the qr factorization
        q, r = np.linalg.qr(a)
        # Make Q uniform
        d = np.diag(r)
        q *= np.sign(d)
        if num_rows < num_cols:
            q = np.transpose(q)
        return std * np.reshape(q, shape)
Example #22
def threefry_seed(seed: int) -> jnp.ndarray:
    """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
  """
    # Avoid OverflowError in X32 mode by first converting ints to int64.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    if isinstance(seed, int):
        seed_arr = jnp.asarray(np.int64(seed))
    else:
        seed_arr = jnp.asarray(seed)
    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32),
                                    [1])
    k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
    k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([k1, k2], 0)
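As the docstring describes, the key is the 64-bit seed split into high and low uint32 halves; a plain-NumPy illustration of that split:

import numpy as np

seed = 0x1234_5678_9ABC_DEF0
hi = np.uint32(seed >> 32)          # corresponds to k1 above
lo = np.uint32(seed & 0xFFFFFFFF)   # corresponds to k2 above
print(hi, lo)  # 305419896 2596069104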
Example #23
    def __init__(self,
                 v=0.,
                 log_density=0.,
                 event_dim=0,
                 validate_args=None,
                 value=None):
        if value is not None:
            v = value
            warnings.warn(
                "`value` argument has been deprecated in favor of `v` argument.",
                FutureWarning)

        if event_dim > jnp.ndim(v):
            raise ValueError(
                'Expected event_dim <= v.dim(), actual {} vs {}'.format(
                    event_dim, jnp.ndim(v)))
        batch_dim = jnp.ndim(v) - event_dim
        batch_shape = jnp.shape(v)[:batch_dim]
        event_shape = jnp.shape(v)[batch_dim:]
        self.v = lax.convert_element_type(v, canonicalize_dtype(jnp.float64))
        # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
        self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
        super(Delta, self).__init__(batch_shape,
                                    event_shape,
                                    validate_args=validate_args)
Example #24
def rint(x):
    _check_arraylike('rint', x)
    dtype = dtypes.dtype(x)
    if dtypes.issubdtype(dtype, np.integer):
        return lax.convert_element_type(x, dtypes.float_)
    if dtypes.issubdtype(dtype, np.complexfloating):
        return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
    return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
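The final branch rounds halfway cases to the nearest even value; a quick example via the public jnp.rint wrapper:

import jax.numpy as jnp

# Ties go to the nearest even value, matching lax.RoundingMethod.TO_NEAREST_EVEN.
print(jnp.rint(jnp.array([0.5, 1.5, 2.5])))  # [0. 2. 2.]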
Example #25
def ldexp(x1, x2):
    _check_arraylike("ldexp", x1, x2)
    x1_dtype = dtypes.dtype(x1)
    x2_dtype = dtypes.dtype(x2)
    if (dtypes.issubdtype(x1_dtype, np.complexfloating)
            or dtypes.issubdtype(x2_dtype, np.inexact)):
        raise ValueError(
            f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

    x1, x2 = _promote_shapes("ldexp", x1, x2)

    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
    info = dtypes.finfo(dtype)
    int_type = _INT_DTYPES[info.bits]

    x1 = lax.convert_element_type(x1, dtype)
    x2 = lax.convert_element_type(x2, int_type)

    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1
    x, e = _normalize_float(x1)
    x2 += e + ((x >> info.nmant) & mask) - bias

    # find underflow/overflow before denormalization
    underflow_cond = x2 < -(bias + info.nmant)
    overflow_cond = x2 > bias

    m = lax.full_like(x, 1, dtype=dtype)

    # denormals
    cond = x2 < -bias + 1
    x2 = _where(cond, x2 + info.nmant, x2)
    m = _where(cond, m / (1 << info.nmant), m)

    x2 = lax.convert_element_type(x2, np.int32)
    x &= ~(mask << info.nmant)
    x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

    x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

    # underflow
    x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
    # overflow
    x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
    # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
    return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
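Functionally, ldexp(x1, x2) computes x1 * 2**x2 by manipulating the exponent bits directly; a minimal check via the public jnp.ldexp wrapper:

import jax.numpy as jnp

print(jnp.ldexp(1.5, 3))   # 12.0  (1.5 * 2**3)
print(jnp.ldexp(1.0, -1))  # 0.5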
Example #26
def randn(stddev=1e-2):
    """An initializer function for random normal coefficients."""
    stddev = lax.convert_element_type(stddev, np.float32)

    def init(rng, shape):
        return stddev * random.normal(rng, shape, dtype=np.float32)

    return init
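A hypothetical usage of the initializer factory defined above:

from jax import random

init = randn(stddev=1e-2)            # returns the `init(rng, shape)` closure
w = init(random.PRNGKey(0), (3, 4))  # float32 samples scaled by stddev
print(w.shape, w.dtype)              # (3, 4) float32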
Example #27
 def unravel(arr):
     chunks = jnp.split(arr, indices[:-1])
     with warnings.catch_warnings():
         warnings.simplefilter(
             "ignore")  # ignore complex-to-real cast warning
         return [
             lax.convert_element_type(chunk.reshape(shape), dtype)
             for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
         ]
Example #28
def mean(a, axis=None, keepdims=False):
    if axis is None:
        normalizer = size(a)
    else:
        normalizer = onp.prod(onp.take(shape(a), axis))
    if onp.issubdtype(_dtype(a), onp.bool_):
        a = lax.convert_element_type(a, onp.int32)
    return true_divide(sum(a, axis, keepdims=keepdims),
                       _constant_like(a, normalizer))
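Because the boolean branch above upcasts to int32 before summing, the mean of a boolean mask is the fraction of True entries:

import jax.numpy as jnp

mask = jnp.array([True, False, True, True])
print(jnp.mean(mask))  # 0.75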
Example #29
 def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
   rng = rng_factory(self.rng())
   tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
             jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
   args = (rng((2, 3), from_dtype),)
   convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
   convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
     convert_element_type)
   check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
Example #30
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     concentration = lax.convert_element_type(self.concentration,
                                              value.dtype)
     normalize_term = (np.sum(gammaln(concentration), axis=-1) -
                       gammaln(np.sum(concentration, axis=-1)))
     return np.sum(np.log(value) *
                   (concentration - 1.), axis=-1) - normalize_term