Ejemplo n.º 1
0
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    """Scatter-combine entries of ``data`` into the segments given by ``segment_ids``.

    Shared implementation of the ``segment_*`` reductions: an output filled
    with the identity of ``scatter_op`` is created and ``data`` is scattered
    into it with ``scatter_op``.

    Args:
      name: Public function name; used for argument checks and error messages.
      data: Values to combine; its leading axis is indexed by ``segment_ids``.
      segment_ids: Segment index for each entry along ``data``'s leading axis.
      scatter_op: Scatter primitive applied through ``_scatter_update``.
      num_segments: Number of output segments. Defaults to
        ``max(segment_ids) + 1``, which must then be a concrete value.
      indices_are_sorted: Hint forwarded to the scatter.
      unique_indices: Hint forwarded to the scatter.
      bucket_size: If given and the input spans more than one bucket, entry
        ``i`` is scattered into row ``i // bucket_size`` of a per-bucket
        output, and the rows are then combined with ``reducer``. This improves
        numerical stability for e.g. sums and products. Must be positive.
      reducer: Called as ``reducer(out, axis=0)`` over the bucket axis;
        required when bucketing actually happens.
      mode: Gather/scatter mode forwarded to the scatter; defaults to
        ``lax.GatherScatterMode.FILL_OR_DROP``.

    Returns:
      Array with shape ``(num_segments,) + data.shape[1:]`` and ``data``'s dtype.

    Raises:
      ValueError: if ``num_segments`` is negative, ``bucket_size`` is not
        positive, or bucketing is required but ``reducer`` is None.
    """
    jnp._check_arraylike(name, data, segment_ids)
    mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # Name the actual public entry point (segment_max, segment_prod, ...)
    # instead of always blaming segment_sum when the value is not concrete.
    num_segments = core.concrete_or_error(
        int, num_segments, f"{name}() `num_segments` argument.")
    if num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    if bucket_size is None:
        num_buckets = 1
    elif bucket_size <= 0:
        raise ValueError("bucket_size must be positive.")
    else:
        num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)

    identity = _get_identity(scatter_op, dtype)

    # Zero buckets (empty segment_ids) is treated like a single bucket: there
    # is nothing to bucketize, and reducing over an empty bucket axis may fail
    # for reducers without an identity element (e.g. a max).
    if num_buckets <= 1:
        out = jnp.full((num_segments,) + data.shape[1:], identity, dtype=dtype)
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False,
                               mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    if reducer is None:
        raise ValueError(
            f"{name}() requires a reducer when data is split into buckets.")
    out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                   identity,
                   dtype=dtype)
    # Entry i goes to bucket row i // bucket_size, segment segment_ids[i].
    bucket_ids = lax.div(jnp.arange(segment_ids.shape[0]), bucket_size)
    out = _scatter_update(out,
                          np.index_exp[bucket_ids, segment_ids[None, :]],
                          data,
                          scatter_op,
                          indices_are_sorted,
                          unique_indices,
                          normalize_indices=False,
                          mode=mode)
    return reducer(out, axis=0).astype(dtype)
Ejemplo n.º 2
0
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
    """Scatter-combine entries of ``data`` into the segments given by ``segment_ids``.

    Shared implementation of the ``segment_*`` reductions: an output filled
    with the identity of ``scatter_op`` is created and ``data`` is scattered
    into it with ``scatter_op``.

    Args:
      name: Public function name; used for argument checks and error messages.
      data: Values to combine; its leading axis is indexed by ``segment_ids``.
      segment_ids: Segment index for each entry along ``data``'s leading axis.
      scatter_op: Scatter primitive applied through ``_scatter_update``.
      num_segments: Number of output segments. Defaults to
        ``max(segment_ids) + 1``, which must then be a concrete value.
      indices_are_sorted: Hint forwarded to the scatter.
      unique_indices: Hint forwarded to the scatter.
      bucket_size: If given and the input spans more than one bucket,
        ``data`` and ``segment_ids`` are split with ``jnp.array_split`` into
        ``ceil(size / bucket_size)`` contiguous pieces. Each piece is reduced
        on its own, and the partial results are combined with ``reducer``.
        Must be positive.
      reducer: Called as ``reducer(stacked, axis=0)`` over the per-bucket
        results; required when bucketing actually happens.

    Returns:
      Array with shape ``(num_segments,) + data.shape[1:]`` on the unbucketed
      path and ``data``'s dtype.

    Raises:
      ValueError: if ``num_segments`` is negative, ``bucket_size`` is not
        positive, or bucketing is required but ``reducer`` is None.
    """
    jnp._check_arraylike(name, data, segment_ids)
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # Name the actual public entry point (segment_max, segment_prod, ...)
    # instead of always blaming segment_sum when the value is not concrete.
    num_segments = core.concrete_or_error(
        int, num_segments, f"{name}() `num_segments` argument.")
    if num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    if bucket_size is None:
        num_buckets = 1
    elif bucket_size <= 0:
        raise ValueError("bucket_size must be positive.")
    else:
        num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)

    # Zero buckets (empty segment_ids) is treated like a single bucket;
    # jnp.array_split cannot split into zero pieces.
    if num_buckets <= 1:
        # The identity-filled output is only needed on this path, so it is
        # allocated here rather than unconditionally.
        out = jnp.full((num_segments,) + data.shape[1:],
                       _get_identity(scatter_op, dtype),
                       dtype=dtype)
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    if reducer is None:
        raise ValueError(
            f"{name}() requires a reducer when data is split into buckets.")
    outs = [
        _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                        num_segments, indices_are_sorted, unique_indices)
        for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets))
    ]
    return reducer(jnp.stack(outs), axis=0).astype(dtype)
Ejemplo n.º 3
0
def polyint(p, m=1, k=None):
  """Return the coefficients of the ``m``-th antiderivative of polynomial ``p``.

  ``p`` lists coefficients from highest to lowest degree. ``k`` holds the
  ``m`` integration constants: ``None`` means zero, and a scalar or a length-1
  array is repeated ``m`` times. ``m`` must be a concrete non-negative integer.
  For ``m == 0`` the (dtype-promoted) ``p`` is returned unchanged.
  """
  order = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  consts = k if k is not None else 0
  _check_arraylike("polyint", p, consts)
  p, consts = _promote_dtypes_inexact(p, consts)
  if order < 0:
    raise ValueError("Order of integral must be positive (see polyder)")

  consts = atleast_1d(consts)
  if len(consts) == 1:
    consts = full((order,), consts[0])
  if consts.shape != (order,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")

  if order == 0:
    return p

  # Row r of ``factors`` holds, for every output coefficient, the factor
  # introduced by the (r+1)-th integration. Their column-wise product divides
  # each original coefficient by the right rising product. Clamping at 1
  # leaves constant k[t] divided by (m - 1 - t)!.
  powers = arange(len(p) + order, 0, -1)[np.newaxis, :]
  shifts = arange(order)[:, np.newaxis]
  factors = maximum(1, powers - 1 - shifts)
  divisors = factors.prod(0)
  return true_divide(concatenate((p, consts)), divisors)
Ejemplo n.º 4
0
def _ndtri(p):
    """Implements ndtri core logic."""

    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(maybe_complement_p <= dtype(0.),
                              jnp.full(shape, dtype(0.5)), maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                                _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (_create_polynomial(dtype(1.) / z, p2) /
                           _create_polynomial(dtype(1.) / z, q2) / z)
    second_term_otherwise = (_create_polynomial(dtype(1.) / z, p1) /
                             _create_polynomial(dtype(1.) / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p,
                  jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

    x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
    infinity = jnp.full(shape, dtype(np.inf))
    x_nan_replaced = jnp.where(p <= dtype(0.0), -infinity,
                               jnp.where(p >= dtype(1.0), infinity, x))
    return x_nan_replaced