Example #1
File: fft.py  Project: xueeinstein/jax
def _fft_core(func_name, fft_type, a, s, axes, norm):
    full_name = "jax.numpy.fft." + func_name

    if s is not None:
        s = tuple(map(operator.index, s))
        if np.any(np.less(s, 0)):
            raise ValueError("Shape should be non-negative.")

    if s is not None and axes is not None and len(s) != len(axes):
        # Same error as numpy.
        raise ValueError("Shape and axes have different lengths.")

    orig_axes = axes
    if axes is None:
        if s is None:
            axes = range(a.ndim)
        else:
            axes = range(a.ndim - len(s), a.ndim)

    if len(axes) != len(set(axes)):
        raise ValueError(
            f"{full_name} does not support repeated axes. Got axes {axes}.")

    if len(axes) > 3:
        # XLA does not support FFTs over more than 3 dimensions
        raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                         "Got axes %s with input rank %s." %
                         (full_name, orig_axes, a.ndim))

    # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
    if orig_axes is not None:
        axes = tuple(range(a.ndim - len(axes), a.ndim))
        a = jnp.moveaxis(a, orig_axes, axes)

    if s is not None:
        a = jnp.asarray(a)
        in_s = list(a.shape)
        for axis, x in safe_zip(axes, s):
            in_s[axis] = x
        if fft_type == xla_client.FftType.IRFFT:
            in_s[-1] = (in_s[-1] // 2 + 1)
        # Cropping
        a = a[tuple(map(slice, in_s))]
        # Padding
        a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
    else:
        if fft_type == xla_client.FftType.IRFFT:
            s = [a.shape[axis] for axis in axes[:-1]]
            if axes:
                s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
        else:
            s = [a.shape[axis] for axis in axes]
    transformed = lax.fft(a, fft_type, tuple(s))
    transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name,
                             norm)

    if orig_axes is not None:
        transformed = jnp.moveaxis(transformed, axes, orig_axes)
    return transformed
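
For reference, a minimal usage sketch (assuming a standard JAX install): the public wrappers such as `jax.numpy.fft.fftn` dispatch through `_fft_core`, so the crop/pad handling of `s` can be observed directly.

import jax.numpy as jnp

a = jnp.ones((4, 8))
# s=(4, 16) keeps the first axis as-is and zero-pads the last axis from 8 to
# 16 before transforming, exercising the "Cropping"/"Padding" branch above.
out = jnp.fft.fftn(a, s=(4, 16))
print(out.shape)  # (4, 16)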
Example #2
def _overlap_and_add(x, step_size):
    """Utility function compatible with tf.signal.overlap_and_add.

  Args:
    x: An array with `(..., frames, frame_length)`-shape.
    step_size: An integer denoting overlap offsets. Must be less than
      `frame_length`.

  Returns:
    An array with `(..., output_size)`-shape containing overlapped signal.
  """
    _check_arraylike("_overlap_and_add", x)
    step_size = jax.core.concrete_or_error(int, step_size,
                                           "step_size for overlap_and_add")
    if x.ndim < 2:
        raise ValueError('Input must have (..., frames, frame_length) shape.')

    *batch_shape, nframes, segment_len = x.shape
    flat_batchsize = np.prod(batch_shape, dtype=np.int64)
    x = x.reshape((flat_batchsize, nframes, segment_len))
    output_size = step_size * (nframes - 1) + segment_len
    nstep_per_segment = 1 + (segment_len - 1) // step_size

    # Here, we use shorter notation for axes:
    # B: batch_size, N: nframes, S: nstep_per_segment,
    # T: step_size (the padded segment_len divided by S).

    padded_segment_len = nstep_per_segment * step_size
    x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_segment_len - segment_len)))
    x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size))

    # To obtain shifted signals, this routine reinterprets the flattened array
    # with a shrunken axis.  With appropriate truncation/padding, this operation
    # pushes the last padded elements of the previous row to the head of the
    # current row.
    # See the implementation of `overlap_and_add` in TensorFlow for details.
    x = x.transpose((0, 2, 1, 3))  # x: (B, S, N, T)
    x = jnp.pad(x, ((0, 0), (0, 0), (0, nframes), (0, 0)))  # x: (B, S, N*2, T)
    shrinked = x.shape[2] - 1
    x = x.reshape((flat_batchsize, -1))
    x = x[:, :(nstep_per_segment * shrinked * step_size)]
    x = x.reshape((flat_batchsize, nstep_per_segment, shrinked * step_size))

    # Finally, sum shifted segments, and truncate results to the output_size.
    x = x.sum(axis=1)[:, :output_size]
    return x.reshape(tuple(batch_shape) + (-1, ))
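
As a behavioral reference (an illustrative NumPy sketch, not the JAX implementation): overlap-and-add places frame i at offset i * step_size in the output and sums the overlaps, matching tf.signal.overlap_and_add. A naive loop makes the contract explicit:

import numpy as np

frames = np.arange(6.0).reshape(2, 3)   # 2 frames of length 3
step_size = 2
output_size = step_size * (2 - 1) + 3   # 5, as computed above
naive = np.zeros(output_size)
for i, frame in enumerate(frames):
    naive[i * step_size:i * step_size + 3] += frame
# naive == [0., 1., 5., 4., 5.]; _overlap_and_add(jnp.asarray(frames), 2)
# should return the same values.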
Example #3
File: signal.py  Project: GJBoth/jax
def pad(x, n, axis=-1):
    # `mode` and `kwargs` are free variables captured from the enclosing
    # scope: this `pad` is a nested helper inside a larger function.
    pad_width = [(0, 0) for unused_n in range(x.ndim)]
    pad_width[axis] = (n, n)
    return jnp.pad(x, pad_width, mode, **kwargs)
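
An equivalent standalone sketch (hypothetical name `pad_axis`, with `mode` made an explicit parameter instead of a closure variable):

import jax.numpy as jnp

def pad_axis(x, n, axis=-1, mode='edge', **kwargs):
    # Pad n elements on both sides of `axis`; other axes are untouched.
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (n, n)
    return jnp.pad(x, pad_width, mode, **kwargs)

print(pad_axis(jnp.arange(3.0), 2).shape)  # (7,)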
Example #4
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray,
                     is_normalized: bool) -> jnp.ndarray:
    """Generates derivatives of associated Legendre functions of the first kind.

  Args:
    p: The 3D array containing the values of associated Legendre functions; the
      dimensions are in the sequence of order (m), degree (l), and evalution
      points.
    x: A vector of type `float32` or `float64` containing the sampled points.
    is_normalized: True if the associated Legendre functions are normalized.
  Returns:
    The 3D array representing the derivatives of associated Legendre functions
    of the first kind.
  """

    num_m, num_l, num_x = p.shape

    # p_{l-1}^m.
    p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]

    # p_{l-1}^{m+2}.
    p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]

    # p_{l-1}^{m-2}.
    p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]

    # Derivative computation requires negative orders.
    if is_normalized:
        raise NotImplementedError(
            'Negative orders for normalization are not implemented yet.')
    else:
        if num_l > 1:
            l_vec = jnp.arange(1, num_l - 1)
            p_p1 = p[1, 1:num_l - 1, :]
            coeff = -1.0 / ((l_vec + 1) * l_vec)
            update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
            p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

        if num_l > 2:
            l_vec = jnp.arange(2, num_l - 1)
            p_p2 = p[2, 2:num_l - 1, :]
            coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
            update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
            p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

    m_mat, l_mat = jnp.mgrid[:num_m, :num_l]

    coeff_zeros = jnp.zeros((num_m, num_l))
    upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
    zero_vec = jnp.zeros((num_l, ))

    a0 = -0.5 / (m_mat - 1.0)
    a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
    a0_masked = a0_masked.at[1, :].set(zero_vec)

    b0 = l_mat + m_mat
    c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
    c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
    c0_masked = c0_masked.at[1, :].set(zero_vec)

    # p_l^{m-1}.
    p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
               jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))

    d0 = -0.5 / (m_mat + 1.0)
    d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
    e0 = d0 * b0 * (b0 + 1.0)
    e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])

    # p_l^{m+1}.
    p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
               jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))

    f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
    f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
    p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked,
                              p_mm1_l) - 0.5 * p_mp1_l

    # Special treatment of the singularity at m = 1.
    if num_m > 1:
        l_vec = jnp.arange(num_l)
        g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
        if num_l > 2:
            g0 = g0 - p[2, :, :]
        p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
        p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
        p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, )))

    return p_derivative
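
The recurring idiom above (pad one side of an axis, then re-slice to the original length) implements an index shift with zero fill. A minimal sketch of how `p_m_lm1` shifts the degree axis:

import jax.numpy as jnp

p = jnp.arange(6.0).reshape(2, 3, 1)   # (num_m, num_l, num_x)
p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :3, :]
# p_m_lm1[:, l, :] == p[:, l - 1, :] for l >= 1; the l = 0 slice is zeros.
print(p_m_lm1[:, 0, :])  # a (2, 1) block of zeros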