Example #1
def square_distance(X: jnp.ndarray, X2: Optional[jnp.ndarray]) -> jnp.ndarray:
    """
    Returns the matrix of pairwise squared distances ||X[i] - X2[j]||².
    If X2 is None, the pairwise distances of X with itself are computed.
    Due to the implementation and floating-point imprecision, the
    result may be very slightly negative for entries that are very
    close to each other.
    This function can deal with leading dimensions in X and X2.
    In the simple case, where X and X2 are both 2-dimensional,
    for example X is [N, D] and X2 is [M, D], a tensor of shape
    [N, M] is returned. If X is [N1, S1, D] and X2 is [N2, S2, D],
    the output will be [N1, S1, N2, S2].
    """
    if X2 is None:
        Xs = jnp.sum(jnp.square(X), axis=-1, keepdims=True)
        XT = leading_transpose(X, perm=[..., -1, -2])
        dist = -2 * jnp.matmul(X, XT)
        XsT = leading_transpose(Xs, perm=[..., -1, -2])
        conj = jnp.conjugate(XsT)
        dist += Xs + conj
        return dist
    Xs = jnp.sum(jnp.square(X), axis=-1, keepdims=True)
    X2s = jnp.sum(jnp.square(X2), axis=-1, keepdims=True)
    X2sT = leading_transpose(X2s, perm=[..., -1, -2])
    X2T = leading_transpose(X2, perm=[..., -1, -2])
    dist = -2 * jnp.matmul(X, X2T)
    dist2 = Xs + X2sT  # broadcast
    dist += dist2
    return dist
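The code relies on a leading_transpose helper that is not shown here. A minimal stand-in (an assumption, covering only the perm=[..., -1, -2] case actually used, i.e. swapping the last two axes) together with a toy call:

import jax.numpy as jnp

def leading_transpose(x, perm):
    # Minimal sketch: only supports perm = [..., -1, -2], which swaps
    # the last two axes and leaves all leading dimensions intact.
    assert perm == [..., -1, -2]
    return jnp.swapaxes(x, -1, -2)

X = jnp.arange(6.0).reshape(3, 2)    # [N=3, D=2]
X2 = jnp.arange(8.0).reshape(4, 2)   # [M=4, D=2]
print(square_distance(X, X2).shape)  # (3, 4)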
Example #2
def screened_poisson_multi_kernel(lmbda, im_in, g, k):
    """Screened Poisson solve in the Fourier domain (per batch and channel):

        ifft((lmbda * fft(im_in) + conj(fft(k)) * fft(g)) / (lmbda + |fft(k)|²)),

    where the conj(fft(k)) * fft(g) and |fft(k)|² terms are summed over the
    kernel's output-channel axis.

    Args:
        im_in (_type_): observation: b,h,w,c
        g (_type_): prior image: b,h,w,c
        k (_type_): kernel: b,kh,kw,c
    """
    k_fft = jnp.fft.fft2(k, axes=(1, 2), s=im_in.shape[1:3])  # b,h,w,c_in,c_out
    im_fft = jnp.fft.fft2(im_in, axes=(1, 2))
    g_fft = jnp.fft.fft2(g, axes=(1, 2))  # b,h,w,c_in*c_out
    g_fft = (jnp.conjugate(k_fft) * g_fft).sum(-1, keepdims=True)

    nom = lmbda * im_fft + g_fft  # b,h,w,c_in
    denom = lmbda + (jnp.conjugate(k_fft) * k_fft).sum(-1, keepdims=True)
    return jnp.real(jnp.fft.ifft2(nom / denom, axes=(1, 2)))
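A toy-shape smoke test (the shapes are assumptions read off the docstring, not from the source):

import jax.numpy as jnp

b, h, w, c = 1, 8, 8, 3
im = jnp.ones((b, h, w, c))        # observation
g = jnp.zeros((b, h, w, c))        # prior image
k = jnp.ones((b, 3, 3, c)) / 9.0   # 3x3 box kernel
out = screened_poisson_multi_kernel(0.1, im, g, k)
print(out.shape)  # (1, 8, 8, 3)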
Example #3
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[Stats, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel_chunked(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
        chunk_size=chunk_size,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            σ,
            conjugate=True,
            chunk_size=chunk_size,
            chunk_argnums=1,
            nondiff_argnums=1,
        )
        new_model_state = None
    else:
        raise NotImplementedError

    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),
    )[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
Example #4
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
Example #5
def _local_values_and_grads_notcentered_kernel(logpsi, pars, vp, mel, v):
    logpsi_vp, f_vjp = nkjax.vjp(lambda w: logpsi(w, vp), pars, conjugate=False)
    vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))

    # TODO : here someone must bring order to those multiple conjs
    odtype = jax.eval_shape(logpsi, pars, v).dtype
    vec = jnp.asarray(jnp.conjugate(vec), dtype=odtype)
    loc_val = vec.sum()
    grad_c = f_vjp(vec.conj())[0]

    return loc_val, grad_c
Example #6
def dag(state):
    r"""Returns conjugate transpose of a given state, represented by :math:`A^{\dagger}`, where :math:`A` is
    a quantum state represented by a ket, a bra or, more generally, a density matrix.

    Args:
        state (:obj:`jnp.ndarray`): State to perform the dagger operation on
     
    Returns:
        :obj:`jnp.ndarray`: Conjugate transposed jax.numpy representation of the input state
 
    """
    return jnp.conjugate(jnp.transpose(state))
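A minimal check (hypothetical, not from the source): applied to a normalized column ket, dag yields the corresponding bra, and ⟨ψ|ψ⟩ = 1.

import jax.numpy as jnp

ket = jnp.array([[1.0], [1.0j]]) / jnp.sqrt(2)  # |ψ⟩ as a 2x1 column vector
bra = dag(ket)                                  # ⟨ψ| as a 1x2 row vector
print(bra @ ket)                                # [[1.+0.j]]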
Example #7
def _fidelity_ket(a, b):
    """Private function that computes fidelity between two kets.
    
    Args:
        a (:obj:`jnp.ndarray`): State vector (ket)
        b (:obj:`jnp.ndarray`): State vector (ket) 
        
    Returns:
        float: fidelity between the two state vectors
    """
    a, b = jnp.asarray(a), jnp.asarray(b)
    return jnp.abs(jnp.dot(jnp.transpose(jnp.conjugate(a)), b))**2
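For instance, the fidelity between |0⟩ and |+⟩ should come out as 0.5 (a hypothetical sanity check):

import jax.numpy as jnp

zero = jnp.array([1.0, 0.0])
plus = jnp.array([1.0, 1.0]) / jnp.sqrt(2)
print(_fidelity_ket(zero, plus))  # ≈ 0.5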
Example #8
def loss(t1, flat_p, psi_init, psi0):
    '''
    Define the loss function, which is a pure function of its inputs.
    '''
    t_set = jnp.linspace(0., t1, 5)

    def func(y, t, *args):
        t1, flat_p, = args
        return -1.0j*Hmat(t, flat_p, t1)@y

    res = odeint(func, psi_init, t_set, t1, flat_p, rtol=1.4e-10, atol=1.4e-10)
    psi_final = res[-1, :]

    return (1 - jnp.abs(jnp.dot(jnp.conjugate(psi_final), psi0))**2)
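loss depends on Hmat and odeint from the enclosing module. A runnable sketch under an assumed constant two-level Hamiltonian (the Hmat below is a hypothetical stand-in):

import jax
jax.config.update("jax_enable_x64", True)  # the 1.4e-10 tolerances in loss suggest 64-bit mode
import jax.numpy as jnp
from jax.experimental.ode import odeint

sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]])

def Hmat(t, flat_p, t1):
    # Hypothetical: a constant Rabi drive with amplitude flat_p[0].
    return flat_p[0] * sigma_x

psi0 = jnp.array([1.0 + 0.0j, 0.0])
print(loss(1.0, jnp.array([0.5]), psi0, psi0))  # infidelity after evolving to t1 = 1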
Example #9
def screen_poisson(lambda_d, img, grad_x, grad_y, axes=(-2, -1)):
    img_freq = jnp.fft.fft2(img,axes=axes)
    grad_x_freq = jnp.fft.fft2(grad_x,axes=axes)
    grad_y_freq = jnp.fft.fft2(grad_y,axes=axes)

    sx = jnp.fft.fftfreq(img.shape[axes[-1]])
    sx = jnp.repeat(sx, img.shape[axes[-2]])
    sx = jnp.reshape(sx, [img.shape[axes[-1]], img.shape[axes[-2]]])
    sx = jnp.transpose(sx)
    sy = jnp.fft.fftfreq(img.shape[axes[-2]])
    sy = jnp.repeat(sy, img.shape[axes[-1]])
    sy = jnp.reshape(sy, img.shape)

    # Fourier transform of shift operators
    Dx_freq = 2 * math.pi * (jnp.exp(-1j * sx) - 1)
    Dy_freq = 2 * math.pi * (jnp.exp(-1j * sy) - 1)

    # my_grad_x_freq = Dx_freq * img_freq
    # my_grad_x_freq & my_grad_y_freq should be the same as grad_x_freq & grad_y_freq

    recon_freq = (lambda_d * img_freq + jnp.conjugate(Dx_freq) * grad_x_freq + jnp.conjugate(Dy_freq) * grad_y_freq) / \
                (lambda_d + (jnp.conjugate(Dx_freq) * Dx_freq + jnp.conjugate(Dy_freq) * Dy_freq))
    return jnp.real(jnp.fft.ifft2(recon_freq, axes=axes))
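A quick smoke test (assumed inputs): with zero target gradients the reconstruction is just a low-pass-filtered version of img.

import math
import jax.numpy as jnp

img = jnp.ones((8, 8))
gx = jnp.zeros_like(img)
gy = jnp.zeros_like(img)
print(screen_poisson(1.0, img, gx, gy).shape)  # (8, 8)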
Example #10
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[Stats, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
Example #11
def _sph_harm(m: jnp.ndarray, n: jnp.ndarray, theta: jnp.ndarray,
              phi: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Computes the spherical harmonics."""

    cos_colatitude = jnp.cos(phi)

    legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
    legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")

    angle = abs(m) * theta
    vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
    harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
                            legendre_val * jnp.imag(vandermonde))

    # Negative order.
    harmonics = jnp.where(m < 0, (-1.0)**abs(m) * jnp.conjugate(harmonics),
                          harmonics)

    return harmonics
Example #12
    def harmonics_nonpositive_order(self, harmonics_nonnegative_order=None):
        """Computes the spherical harmonics of nonpositive orders.

    With normalization, the nonnegative order Associated Legendre functions are
    `P_l^{-m}(x) = (−1)^m P_l^m(x)`, which implies that
    `Y_l^{-m}(θ, φ) = (−1)^m conjugate(Y_l^m(θ, φ))`.

    Args:
      harmonics_nonnegative_order: A 4D complex tensor representing the
        harmonics of nonnegative orders, the shape of which is
        `(l_max + 1, l_max + 1, num_theta, num_phi)` and the dimensions are in
        the sequence of degree, order, colatitude, and longitude.
    Returns:
      A 4D complex tensor of the same shape as `harmonics_nonnegative_order`
      representing the harmonics of nonpositive orders.
    """
        if harmonics_nonnegative_order is None:
            harmonics_nonnegative_order = self.harmonics_nonnegative_order()

        mask = self._gen_mask()
        return jnp.einsum('j,ijkl->ijkl', mask,
                          jnp.conjugate(harmonics_nonnegative_order))
Example #13
def dag(C):
    return np.conjugate(C).T
Example #14
def _transpose(a, perm=None, conjugate=False, name='transpose'):  # pylint: disable=unused-argument
    x = np.transpose(a, perm)
    return np.conjugate(x) if conjugate else x
Example #15
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    if not has_aux:
        out_axes = (0, 0)
    else:
        out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    local_kernel_vmap = jax.vmap(
        partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = netket.operator._der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # netket.operator._der_local_values_jax._local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # old implementation
    # this is faster, even though i think the one below should be faster
    # (this works, but... yeah. let's keep it here and delete in a while.)
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO
    # NEW IMPLEMENTATION
    # This should be faster, but should benchmark as it seems slower
    # to compute der_logs_ave i can just do a jvp with a ones vector
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    # der_logs_ave = d_logpsi(
    #    jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]
    der_logs_ave = tree_map(sum_inplace, der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
Example #16
def norm(y):
    # Note: returns the *squared* 2-norm ⟨y, y⟩ = Σ|yᵢ|² as a real scalar,
    # despite the name.
    return jnp.real(jnp.conjugate(y).dot(y))
Example #17
def rmsNorm(x):
    squareNorm = np.sum(x * np.conjugate(x))
    size = np.prod(np.array(x.shape))
    return np.sqrt(squareNorm / size)
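For example (assuming np is NumPy or jax.numpy, as in the neighbouring snippets):

import numpy as np

print(rmsNorm(np.array([3.0, 4.0])))  # sqrt(25 / 2) ≈ 3.5355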
Example #18
def _spectral_helper(x,
                     y,
                     fs=1.0,
                     window='hann',
                     nperseg=None,
                     noverlap=None,
                     nfft=None,
                     detrend='constant',
                     return_onesided=True,
                     scaling='density',
                     axis=-1,
                     mode='psd',
                     boundary=None,
                     padded=False):
    """
    Calculate various forms of windowed FFTs for PSD, CSD, etc.

    This is a helper function that implements the commonality between
    the stft, psd, csd, and spectrogram functions. It is not designed to
    be called externally. The windows are not averaged over; the result
    from each window is returned.

    Parameters
    ----------
    x : array_like
        Array or sequence containing the data to be analyzed.
    y : array_like
        Array or sequence containing the data to be analyzed. If this is
        the same object in memory as `x` (i.e. ``_spectral_helper(x,
        x, ...)``), the extra computations are spared.
    fs : float, optional
        Sampling frequency of the time series. Defaults to 1.0.
    window : str or tuple or array_like, optional
        Desired window to use. If `window` is a string or tuple, it is
        passed to `get_window` to generate the window values, which are
        DFT-even by default. See `get_window` for a list of windows and
        required parameters. If `window` is array_like it will be used
        directly as the window and its length must be nperseg. Defaults
        to a Hann window.
    nperseg : int, optional
        Length of each segment. Defaults to None, but if window is str or
        tuple, is set to 256, and if window is array_like, is set to the
        length of the window.
    noverlap : int, optional
        Number of points to overlap between segments. If `None`,
        ``noverlap = nperseg // 2``. Defaults to `None`.
    nfft : int, optional
        Length of the FFT used, if a zero padded FFT is desired. If
        `None`, the FFT length is `nperseg`. Defaults to `None`.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. If `detrend` is a
        string, it is passed as the `type` argument to the `detrend`
        function. If it is a function, it takes a segment and returns a
        detrended segment. If `detrend` is `False`, no detrending is
        done. Defaults to 'constant'.
    return_onesided : bool, optional
        If `True`, return a one-sided spectrum for real data. If
        `False` return a two-sided spectrum. Defaults to `True`, but for
        complex data, a two-sided spectrum is always returned.
    scaling : { 'density', 'spectrum' }, optional
        Selects between computing the cross spectral density ('density')
        where `Pxy` has units of V**2/Hz and computing the cross
        spectrum ('spectrum') where `Pxy` has units of V**2, if `x`
        and `y` are measured in V and `fs` is measured in Hz.
        Defaults to 'density'
    axis : int, optional
        Axis along which the FFTs are computed; the default is over the
        last axis (i.e. ``axis=-1``).
    mode : str {'psd', 'stft'}, optional
        Defines what kind of return values are expected. Defaults to
        'psd'.
    boundary : str or None, optional
        Specifies whether the input signal is extended at both ends, and
        how to generate the new values, in order to center the first
        windowed segment on the first input point. This has the benefit
        of enabling reconstruction of the first input point when the
        employed window function starts at zero. Valid options are
        ``['even', 'odd', 'constant', 'zeros', None]``. Defaults to
        `None`.
    padded : bool, optional
        Specifies whether the input signal is zero-padded at the end to
        make the signal fit exactly into an integer number of window
        segments, so that all of the signal is included in the output.
        Defaults to `False`. Padding occurs after boundary extension, if
        `boundary` is not `None`, and `padded` is `True`.
    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.

    Notes
    -----
    Adapted from matplotlib.mlab

    .. versionadded:: 0.16.0
    """
    if mode not in ['psd', 'stft']:
        raise ValueError("Unknown value for mode %s, must be one of: "
                         "{'psd', 'stft'}" % mode)

    boundary_funcs = {
        'even': even_ext,
        'odd': odd_ext,
        'constant': const_ext,
        'zeros': zero_ext,
        None: None
    }

    if boundary not in boundary_funcs:
        raise ValueError(
            "Unknown boundary option '{0}', must be one of: {1}".format(
                boundary, list(boundary_funcs.keys())))

    # If x and y are the same object we can save ourselves some computation.
    same_data = y is x

    if not same_data and mode != 'psd':
        raise ValueError("x and y must be equal if mode is 'stft'")

    axis = int(axis)

    # Ensure we have np.arrays, get outdtype
    x = np.asarray(x)
    if not same_data:
        y = np.asarray(y)
        outdtype = np.result_type(x, y, np.complex64)
    else:
        outdtype = np.result_type(x, np.complex64)

    if not same_data:
        # Check if we can broadcast the outer axes together
        xouter = list(x.shape)
        youter = list(y.shape)
        xouter.pop(axis)
        youter.pop(axis)
        try:
            outershape = np.broadcast(np.empty(xouter), np.empty(youter)).shape
        except ValueError:
            raise ValueError('x and y cannot be broadcast together.')

    if same_data:
        if x.size == 0:
            return np.empty(x.shape), np.empty(x.shape), np.empty(x.shape)
    else:
        if x.size == 0 or y.size == 0:
            outshape = outershape + (min([x.shape[axis], y.shape[axis]]), )
            emptyout = np.rollaxis(np.empty(outshape), -1, axis)
            return emptyout, emptyout, emptyout

    if x.ndim > 1:
        if axis != -1:
            x = np.rollaxis(x, axis, len(x.shape))
            if not same_data and y.ndim > 1:
                y = np.rollaxis(y, axis, len(y.shape))

    # Check if x and y are the same length, zero-pad if necessary
    if not same_data:
        if x.shape[-1] != y.shape[-1]:
            if x.shape[-1] < y.shape[-1]:
                pad_shape = list(x.shape)
                pad_shape[-1] = y.shape[-1] - x.shape[-1]
                x = np.concatenate((x, np.zeros(pad_shape)), -1)
            else:
                pad_shape = list(y.shape)
                pad_shape[-1] = x.shape[-1] - y.shape[-1]
                y = np.concatenate((y, np.zeros(pad_shape)), -1)

    if nperseg is not None:  # if specified by user
        nperseg = int(nperseg)
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')

    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = _triage_segments(window, nperseg, input_length=x.shape[-1])

    if nfft is None:
        nfft = nperseg
    elif nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    else:
        nfft = int(nfft)

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = int(noverlap)
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Padding occurs after boundary extension, so that the extended signal ends
    # in zeros, instead of introducing an impulse at the end.
    # I.e. if x = [..., 3, 2]
    # extend then pad -> [..., 3, 2, 2, 3, 0, 0, 0]
    # pad then extend -> [..., 3, 2, 0, 0, 0, 2, 3]

    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if not same_data:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        x = np.concatenate((x, np.zeros(zeros_shape)), axis=-1)
        if not same_data:
            zeros_shape = list(y.shape[:-1]) + [nadd]
            y = np.concatenate((y, np.zeros(zeros_shape)), axis=-1)

    # Handle detrending and window functions
    if not detrend:

        def detrend_func(d):
            return d
    elif not hasattr(detrend, '__call__'):
        raise NotImplementedError()
        # def detrend_func(d):
        #     return signaltools.detrend(d, type=detrend, axis=-1)
    elif axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        def detrend_func(d):
            d = np.rollaxis(d, -1, axis)
            d = detrend(d)
            return np.rollaxis(d, axis, len(d.shape))
    else:
        detrend_func = detrend

    if np.result_type(win, np.complex64) != outdtype:
        win = win.astype(outdtype)

    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError('Unknown scaling: %r' % scaling)

    if mode == 'stft':
        scale = np.sqrt(scale)

    if return_onesided:
        if np.iscomplexobj(x):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
        else:
            sides = 'onesided'
            if not same_data:
                if np.iscomplexobj(y):
                    sides = 'twosided'
                    warnings.warn('Input data is complex, switching to '
                                  'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = np.fft.fftfreq(nfft, 1 / fs)
    elif sides == 'onesided':
        freqs = np.fft.rfftfreq(nfft, 1 / fs)

    # Perform the windowed FFTs
    result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

    if not same_data:
        # All the same operations on the y data
        result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                               sides)
        result = np.conjugate(result) * result_y
    elif mode == 'psd':
        result = np.conjugate(result) * result

    result *= scale
    if sides == 'onesided' and mode == 'psd':
        if nfft % 2:
            result[..., 1:] *= 2
        else:
            # Last point is unpaired Nyquist freq point, don't double
            result[..., 1:-1] *= 2

    time = np.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                     nperseg - noverlap) / float(fs)
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(outdtype)

    # All imaginary parts are zero anyways
    if same_data and mode != 'stft':
        result = result.real

    # Output is going to have new last axis for time/window index, so a
    # negative axis index shifts down one
    if axis < 0:
        axis -= 1

    # Roll frequency axis back to axis where the data came from
    result = np.rollaxis(result, -1, axis)

    return freqs, time, result
Example #19
 def conjugate(self):
     """ Returns the conjugate of a tensor. """
     return Tensor(self.dom, self.cod, np.conjugate(self.array))
Example #20
def J_dag():
    J_dag = jnp.transpose(jnp.conjugate(J()))
    return J_dag
Example #21
 def dagger(self):
     array = np.moveaxis(self.array, range(len(self.dom @ self.cod)), [
         i + len(self.cod) if i < len(self.dom) else i - len(self.dom)
         for i in range(len(self.dom @ self.cod))
     ])
     return Tensor(self.cod, self.dom, np.conjugate(array))
Example #22
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for standard gradient of energy.
    # this avoid errors in #867, #789, #850
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
Example #23
def overlap_exact(psi1, psi2):
    '''
    Take in two states and return the overlap <psi1 | psi2>.
    psi1 is not complex-conjugated beforehand; the conjugation is applied here.
    '''
    return np.dot(np.conjugate(psi1), psi2)
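For instance, a normalized state overlaps itself with 1 (a hypothetical check):

import numpy as np

psi = np.array([1.0, 1.0j]) / np.sqrt(2)
print(overlap_exact(psi, psi))  # (1+0j)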
Example #24
def Cc(x: Array) -> Array:
    return np.conjugate(x)
def cost(params):
    state = circuit(params)
    op_qs = jnp.array([[0], [0.707], [0.707], [0]])
    fid = jnp.abs(jnp.dot(jnp.transpose(jnp.conjugate(op_qs)), state))**2
    return -jnp.real(fid)[0][0]