def square_distance(X: jnp.ndarray, X2: jnp.ndarray) -> jnp.ndarray:
    """
    Returns ||X - X2ᵀ||²

    Due to the implementation and floating-point imprecision, the
    result may actually be very slightly negative for entries very
    close to each other.

    This function can deal with leading dimensions in X and X2.
    In the simple case, where X and X2 are both 2 dimensional,
    for example, X is [N, D] and X2 is [M, D], then a tensor of
    shape [N, M] is returned.
    If X is [N1, S1, D] and X2 is [N2, S2, D] then the output will
    be [N1, S1, N2, S2].
    """
    if X2 is None:
        Xs = jnp.sum(jnp.square(X), axis=-1, keepdims=True)
        XT = leading_transpose(X, perm=[..., -1, -2])
        dist = -2 * jnp.matmul(X, XT)
        XsT = leading_transpose(Xs, perm=[..., -1, -2])
        conj = jnp.conjugate(XsT)
        dist += Xs + conj
        return dist
    Xs = jnp.sum(jnp.square(X), axis=-1, keepdims=True)
    X2s = jnp.sum(jnp.square(X2), axis=-1, keepdims=True)
    X2sT = leading_transpose(X2s, perm=[..., -1, -2])
    X2T = leading_transpose(X2, perm=[..., -1, -2])
    dist = -2 * jnp.matmul(X, X2T)
    dist2 = Xs + X2sT  # broadcast
    dist += dist2
    return dist
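# A minimal usage sketch for square_distance. `leading_transpose` is assumed
# to permute axes as in GPflow's helper of the same name; the stub below is a
# hypothetical stand-in that only handles perm=[..., -1, -2].
import jax.numpy as jnp

def leading_transpose(x, perm):
    # Hypothetical stand-in: swap the last two axes, keeping leading dims.
    return jnp.swapaxes(x, -1, -2)

X = jnp.array([[0.0, 0.0], [1.0, 0.0]])   # [N=2, D=2]
X2 = jnp.array([[0.0, 1.0]])               # [M=1, D=2]
print(square_distance(X, X2))              # [[1.], [2.]] — shape [N, M]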
def screened_poisson_multi_kernel(lmbda, im_in, g, k):
    """Computes ifft((lmbda * fft(im_in) + conj(fft(k)) * fft(g)) / (lmbda + |fft(k)|^2))

    Args:
        lmbda: regularization weight (scalar)
        im_in: observation: b,h,w,c
        g: prior image: b,h,w,c
        k: kernel: b,kh,kw,c
    """
    k_fft = jnp.fft.fft2(k, axes=(1, 2), s=im_in.shape[1:3])  # b,h,w,c_in,c_out
    im_fft = jnp.fft.fft2(im_in, axes=(1, 2))
    g_fft = jnp.fft.fft2(g, axes=(1, 2))  # b,h,w,c_in*c_out
    g_fft = (jnp.conjugate(k_fft) * g_fft).sum(-1, keepdims=True)
    nom = lmbda * im_fft + g_fft  # b,h,w,c_in
    denom = lmbda + (jnp.conjugate(k_fft) * k_fft).sum(-1, keepdims=True)
    return jnp.real(jnp.fft.ifft2(nom / denom, axes=(1, 2)))
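# A minimal usage sketch for screened_poisson_multi_kernel, assuming the
# batched b,h,w,c shapes documented above; the values here are made up.
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
im = jax.random.normal(k1, (1, 8, 8, 1))         # observation b,h,w,c
g = jax.random.normal(k2, (1, 8, 8, 1))          # prior image b,h,w,c
k = jnp.ones((1, 3, 3, 1)) / 9.0                 # box-blur kernel b,kh,kw,c
out = screened_poisson_multi_kernel(0.1, im, g, k)
print(out.shape)                                  # (1, 8, 8, 1)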
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel_chunked(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
        chunk_size=chunk_size,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            σ,
            conjugate=True,
            chunk_size=chunk_size,
            chunk_argnums=1,
            nondiff_argnums=1,
        )
        new_model_state = None
    else:
        raise NotImplementedError

    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),
    )[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
def _local_values_and_grads_notcentered_kernel(logpsi, pars, vp, mel, v):
    logpsi_vp, f_vjp = nkjax.vjp(lambda w: logpsi(w, vp), pars, conjugate=False)

    vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))

    # TODO: here someone must bring order to those multiple conjs
    odtype = jax.eval_shape(logpsi, pars, v).dtype
    vec = jnp.asarray(jnp.conjugate(vec), dtype=odtype)

    loc_val = vec.sum()
    grad_c = f_vjp(vec.conj())[0]

    return loc_val, grad_c
def dag(state):
    r"""Returns the conjugate transpose of a given state, represented by
    :math:`A^{\dagger}`, where :math:`A` is a quantum state represented by
    a ket, a bra or, more generally, a density matrix.

    Args:
        state (:obj:`jnp.ndarray`): State to perform the dagger operation on

    Returns:
        :obj:`jnp.ndarray`: Conjugate transposed jax.numpy representation of
            the input state
    """
    return jnp.conjugate(jnp.transpose(state))
def _fidelity_ket(a, b):
    """Private function that computes fidelity between two kets.

    Args:
        a (:obj:`jnp.ndarray`): State vector (ket)
        b (:obj:`jnp.ndarray`): State vector (ket)

    Returns:
        float: fidelity between the two state vectors
    """
    a, b = jnp.asarray(a), jnp.asarray(b)
    return jnp.abs(jnp.dot(jnp.transpose(jnp.conjugate(a)), b)) ** 2
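# A minimal usage sketch for dag and _fidelity_ket: identical kets give
# fidelity 1, and |0> against |+> gives 1/2.
import jax.numpy as jnp

ket0 = jnp.array([1.0 + 0j, 0.0])
ket_plus = jnp.array([1.0, 1.0]) / jnp.sqrt(2.0)
print(dag(ket0))                        # conjugate of |0> (transpose is a no-op on 1D arrays)
print(_fidelity_ket(ket0, ket0))        # 1.0
print(_fidelity_ket(ket0, ket_plus))    # 0.5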
def loss(t1, flat_p, psi_init, psi0):
    '''
    define the loss function, which is a pure function
    '''
    t_set = jnp.linspace(0., t1, 5)

    def func(y, t, *args):
        t1, flat_p, = args
        return -1.0j * Hmat(t, flat_p, t1) @ y

    res = odeint(func, psi_init, t_set, t1, flat_p, rtol=1.4e-10, atol=1.4e-10)
    psi_final = res[-1, :]
    return (1 - jnp.abs(jnp.dot(jnp.conjugate(psi_final), psi0)) ** 2)
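# A minimal sketch of calling loss, assuming a hypothetical two-level
# Hamiltonian Hmat(t, flat_p, t1); the real Hmat is defined elsewhere.
import jax.numpy as jnp
from jax.experimental.ode import odeint

def Hmat(t, flat_p, t1):
    # Hypothetical driven two-level system: sigma_z plus a sine-pulse drive.
    sz = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
    sx = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
    return sz + flat_p[0] * jnp.sin(jnp.pi * t / t1) * sx

psi_init = jnp.array([1.0 + 0j, 0.0])   # start in |0>
psi0 = jnp.array([0.0, 1.0 + 0j])       # target state |1>
print(loss(1.0, jnp.array([0.5]), psi_init, psi0))   # infidelity in [0, 1]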
def screen_poisson(lambda_d, img, grad_x, grad_y, axes=(-2, -1)):
    img_freq = jnp.fft.fft2(img, axes=axes)
    grad_x_freq = jnp.fft.fft2(grad_x, axes=axes)
    grad_y_freq = jnp.fft.fft2(grad_y, axes=axes)

    sx = jnp.fft.fftfreq(img.shape[axes[-1]])
    sx = jnp.repeat(sx, img.shape[axes[-2]])
    sx = jnp.reshape(sx, [img.shape[axes[-1]], img.shape[axes[-2]]])
    sx = jnp.transpose(sx)

    sy = jnp.fft.fftfreq(img.shape[axes[-2]])
    sy = jnp.repeat(sy, img.shape[axes[-1]])
    sy = jnp.reshape(sy, img.shape)

    # Fourier transform of shift operators
    Dx_freq = 2 * math.pi * (jnp.exp(-1j * sx) - 1)
    Dy_freq = 2 * math.pi * (jnp.exp(-1j * sy) - 1)

    # my_grad_x_freq = Dx_freq * img_freq
    # my_grad_x_freq & my_grad_y_freq should be the same as grad_x_freq & grad_y_freq
    recon_freq = (lambda_d * img_freq +
                  jnp.conjugate(Dx_freq) * grad_x_freq +
                  jnp.conjugate(Dy_freq) * grad_y_freq) / \
                 (lambda_d +
                  (jnp.conjugate(Dx_freq) * Dx_freq +
                   jnp.conjugate(Dy_freq) * Dy_freq))

    # invert over the same axes that were transformed above
    return jnp.real(jnp.fft.ifft2(recon_freq, axes=axes))
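# A minimal usage sketch for screen_poisson: reconstruct a 2D image from its
# circular forward-difference gradients, assuming the default axes=(-2, -1).
import math
import jax.numpy as jnp

img = jnp.arange(16.0).reshape(4, 4)
grad_x = jnp.roll(img, -1, axis=-1) - img   # forward difference along x
grad_y = jnp.roll(img, -1, axis=-2) - img   # forward difference along y
recon = screen_poisson(1.0, img, grad_x, grad_y)
print(recon.shape)                           # (4, 4)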
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def _sph_harm(m: jnp.ndarray,
              n: jnp.ndarray,
              theta: jnp.ndarray,
              phi: jnp.ndarray,
              n_max: int) -> jnp.ndarray:
    """Computes the spherical harmonics."""
    cos_colatitude = jnp.cos(phi)
    legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
    legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
    angle = abs(m) * theta
    vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
    harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
                            legendre_val * jnp.imag(vandermonde))

    # Negative order.
    harmonics = jnp.where(m < 0,
                          (-1.0)**abs(m) * jnp.conjugate(harmonics),
                          harmonics)
    return harmonics
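# A sketch of exercising this helper through the public wrapper
# jax.scipy.special.sph_harm, assuming it is available in the installed JAX
# release; it follows scipy's convention (theta azimuthal, phi polar).
import jax.numpy as jnp
from jax.scipy.special import sph_harm

m = jnp.array([-1, 0, 1])
n = jnp.array([1, 1, 1])
theta = jnp.array([0.5, 0.5, 0.5])   # azimuthal angle
phi = jnp.array([0.3, 0.3, 0.3])     # polar (colatitude) angle
print(sph_harm(m, n, theta, phi, n_max=1))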
def harmonics_nonpositive_order(self, harmonics_nonnegative_order=None):
    """Computes the spherical harmonics of nonpositive orders.

    With normalization, the associated Legendre functions satisfy
    `P_l^{-m}(x) = (−1)^m P_l^m(x)`, which implies that
    `Y_l^{-m}(θ, φ) = (−1)^m conjugate(Y_l^m(θ, φ))`.

    Args:
        harmonics_nonnegative_order: A 4D complex tensor representing the
            harmonics of nonnegative orders, the shape of which is
            `(l_max + 1, l_max + 1, num_theta, num_phi)` and the dimensions
            are in the sequence of degree, order, colatitude, and longitude.

    Returns:
        A 4D complex tensor of the same shape as
        `harmonics_nonnegative_order` representing the harmonics of
        nonpositive orders.
    """
    if harmonics_nonnegative_order is None:
        harmonics_nonnegative_order = self.harmonics_nonnegative_order()

    mask = self._gen_mask()
    return jnp.einsum('j,ijkl->ijkl', mask,
                      jnp.conjugate(harmonics_nonnegative_order))
def dag(C):
    return np.conjugate(C).T
def _transpose(a, perm=None, conjugate=False, name='transpose'):  # pylint: disable=unused-argument
    x = np.transpose(a, perm)
    return np.conjugate(x) if conjugate else x
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    if not has_aux:
        out_axes = (0, 0)
    else:
        out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    local_kernel_vmap = jax.vmap(
        partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = netket.operator._der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # netket.operator._der_local_values_jax._local_values_and_grads_notcentered_kernel
    # returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # old implementation
    # this is faster, even though i think the one below should be faster
    # (this works, but... yeah. let's keep it here and delete in a while.)
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO
    # NEW IMPLEMENTATION
    # This should be faster, but should benchmark as it seems slower
    # to compute der_logs_ave i can just do a jvp with a ones vector
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    # der_logs_ave = d_logpsi(
    #     jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]
    der_logs_ave = tree_map(sum_inplace, der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
def norm(y):
    return jnp.real(jnp.conjugate(y).dot(y))
def rmsNorm(x):
    squareNorm = np.sum(x * np.conjugate(x))
    size = np.prod(np.array(x.shape))
    # x * conj(x) has zero imaginary part but a complex dtype; take the real
    # part so complex inputs return a real scalar.
    return np.sqrt(np.real(squareNorm) / size)
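# A quick check of norm and rmsNorm on a complex vector: |1|^2 + |i|^2 = 2,
# so the RMS over 2 entries is 1.
import numpy as np
import jax.numpy as jnp

y = jnp.array([1.0 + 0j, 1j])
print(norm(y))                             # 2.0
print(rmsNorm(np.array([1.0 + 0j, 1j])))   # 1.0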
def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False):
    """Calculate various forms of windowed FFTs for PSD, CSD, etc.

    This is a helper function that implements the commonality between
    the stft, psd, csd, and spectrogram functions. It is not designed to
    be called externally. The windows are not averaged over; the result
    from each window is returned.

    Parameters
    ----------
    x : array_like
        Array or sequence containing the data to be analyzed.
    y : array_like
        Array or sequence containing the data to be analyzed. If this is
        the same object in memory as `x` (i.e. ``_spectral_helper(x, x, ...)``),
        the extra computations are spared.
    fs : float, optional
        Sampling frequency of the time series. Defaults to 1.0.
    window : str or tuple or array_like, optional
        Desired window to use. If `window` is a string or tuple, it is
        passed to `get_window` to generate the window values, which are
        DFT-even by default. See `get_window` for a list of windows and
        required parameters. If `window` is array_like it will be used
        directly as the window and its length must be nperseg. Defaults
        to a Hann window.
    nperseg : int, optional
        Length of each segment. Defaults to None, but if window is str or
        tuple, is set to 256, and if window is array_like, is set to the
        length of the window.
    noverlap : int, optional
        Number of points to overlap between segments. If `None`,
        ``noverlap = nperseg // 2``. Defaults to `None`.
    nfft : int, optional
        Length of the FFT used, if a zero padded FFT is desired. If
        `None`, the FFT length is `nperseg`. Defaults to `None`.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. If `detrend` is a
        string, it is passed as the `type` argument to the `detrend`
        function. If it is a function, it takes a segment and returns a
        detrended segment. If `detrend` is `False`, no detrending is
        done. Defaults to 'constant'.
    return_onesided : bool, optional
        If `True`, return a one-sided spectrum for real data. If
        `False` return a two-sided spectrum. Defaults to `True`, but for
        complex data, a two-sided spectrum is always returned.
    scaling : { 'density', 'spectrum' }, optional
        Selects between computing the cross spectral density ('density')
        where `Pxy` has units of V**2/Hz and computing the cross
        spectrum ('spectrum') where `Pxy` has units of V**2, if `x`
        and `y` are measured in V and `fs` is measured in Hz.
        Defaults to 'density'.
    axis : int, optional
        Axis along which the FFTs are computed; the default is over the
        last axis (i.e. ``axis=-1``).
    mode : str {'psd', 'stft'}, optional
        Defines what kind of return values are expected. Defaults to
        'psd'.
    boundary : str or None, optional
        Specifies whether the input signal is extended at both ends, and
        how to generate the new values, in order to center the first
        windowed segment on the first input point. This has the benefit
        of enabling reconstruction of the first input point when the
        employed window function starts at zero. Valid options are
        ``['even', 'odd', 'constant', 'zeros', None]``. Defaults to
        `None`.
    padded : bool, optional
        Specifies whether the input signal is zero-padded at the end to
        make the signal fit exactly into an integer number of window
        segments, so that all of the signal is included in the output.
        Defaults to `False`. Padding occurs after boundary extension, if
        `boundary` is not `None`, and `padded` is `True`.

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.

    Notes
    -----
    Adapted from matplotlib.mlab

    .. versionadded:: 0.16.0
    """
    if mode not in ['psd', 'stft']:
        raise ValueError("Unknown value for mode %s, must be one of: "
                         "{'psd', 'stft'}" % mode)

    boundary_funcs = {'even': even_ext,
                      'odd': odd_ext,
                      'constant': const_ext,
                      'zeros': zero_ext,
                      None: None}

    if boundary not in boundary_funcs:
        raise ValueError(
            "Unknown boundary option '{0}', must be one of: {1}".format(
                boundary, list(boundary_funcs.keys())))

    # If x and y are the same object we can save ourselves some computation.
    same_data = y is x

    if not same_data and mode != 'psd':
        raise ValueError("x and y must be equal if mode is 'stft'")

    axis = int(axis)

    # Ensure we have np.arrays, get outdtype
    x = np.asarray(x)
    if not same_data:
        y = np.asarray(y)
        outdtype = np.result_type(x, y, np.complex64)
    else:
        outdtype = np.result_type(x, np.complex64)

    if not same_data:
        # Check if we can broadcast the outer axes together
        xouter = list(x.shape)
        youter = list(y.shape)
        xouter.pop(axis)
        youter.pop(axis)
        try:
            outershape = np.broadcast(np.empty(xouter), np.empty(youter)).shape
        except ValueError:
            raise ValueError('x and y cannot be broadcast together.')

    if same_data:
        if x.size == 0:
            return np.empty(x.shape), np.empty(x.shape), np.empty(x.shape)
    else:
        if x.size == 0 or y.size == 0:
            outshape = outershape + (min([x.shape[axis], y.shape[axis]]),)
            emptyout = np.rollaxis(np.empty(outshape), -1, axis)
            return emptyout, emptyout, emptyout

    if x.ndim > 1:
        if axis != -1:
            x = np.rollaxis(x, axis, len(x.shape))
            if not same_data and y.ndim > 1:
                y = np.rollaxis(y, axis, len(y.shape))

    # Check if x and y are the same length, zero-pad if necessary
    if not same_data:
        if x.shape[-1] != y.shape[-1]:
            if x.shape[-1] < y.shape[-1]:
                pad_shape = list(x.shape)
                pad_shape[-1] = y.shape[-1] - x.shape[-1]
                x = np.concatenate((x, np.zeros(pad_shape)), -1)
            else:
                pad_shape = list(y.shape)
                pad_shape[-1] = x.shape[-1] - y.shape[-1]
                y = np.concatenate((y, np.zeros(pad_shape)), -1)

    if nperseg is not None:  # if specified by user
        nperseg = int(nperseg)
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')

    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = _triage_segments(window, nperseg, input_length=x.shape[-1])

    if nfft is None:
        nfft = nperseg
    elif nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    else:
        nfft = int(nfft)

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = int(noverlap)
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Padding occurs after boundary extension, so that the extended signal
    # ends in zeros, instead of introducing an impulse at the end.
    # I.e. if x = [..., 3, 2]
    #   extend then pad -> [..., 3, 2, 2, 3, 0, 0, 0]
    #   pad then extend -> [..., 3, 2, 0, 0, 0, 2, 3]

    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if not same_data:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        x = np.concatenate((x, np.zeros(zeros_shape)), axis=-1)
        if not same_data:
            zeros_shape = list(y.shape[:-1]) + [nadd]
            y = np.concatenate((y, np.zeros(zeros_shape)), axis=-1)

    # Handle detrending and window functions
    if not detrend:
        def detrend_func(d):
            return d
    elif not hasattr(detrend, '__call__'):
        raise NotImplementedError()
        # def detrend_func(d):
        #     return signaltools.detrend(d, type=detrend, axis=-1)
    elif axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        def detrend_func(d):
            d = np.rollaxis(d, -1, axis)
            d = detrend(d)
            return np.rollaxis(d, axis, len(d.shape))
    else:
        detrend_func = detrend

    if np.result_type(win, np.complex64) != outdtype:
        win = win.astype(outdtype)

    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError('Unknown scaling: %r' % scaling)

    if mode == 'stft':
        scale = np.sqrt(scale)

    if return_onesided:
        if np.iscomplexobj(x):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
        else:
            sides = 'onesided'
            if not same_data:
                if np.iscomplexobj(y):
                    sides = 'twosided'
                    warnings.warn('Input data is complex, switching to '
                                  'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = np.fft.fftfreq(nfft, 1 / fs)
    elif sides == 'onesided':
        freqs = np.fft.rfftfreq(nfft, 1 / fs)

    # Perform the windowed FFTs
    result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

    if not same_data:
        # All the same operations on the y data
        result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                               sides)
        result = np.conjugate(result) * result_y
    elif mode == 'psd':
        result = np.conjugate(result) * result

    result *= scale
    if sides == 'onesided' and mode == 'psd':
        if nfft % 2:
            result[..., 1:] *= 2
        else:
            # Last point is unpaired Nyquist freq point, don't double
            result[..., 1:-1] *= 2

    time = np.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                     nperseg - noverlap) / float(fs)
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(outdtype)

    # All imaginary parts are zero anyways
    if same_data and mode != 'stft':
        result = result.real

    # Output is going to have new last axis for time/window index, so a
    # negative axis index shifts down one
    if axis < 0:
        axis -= 1

    # Roll frequency axis back to axis where the data came from
    result = np.rollaxis(result, -1, axis)

    return freqs, time, result
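# A minimal sketch of the PSD path this helper implements, using the public
# jax.scipy.signal.welch wrapper (assumed available in recent JAX releases).
import jax.numpy as jnp
from jax.scipy import signal

fs = 100.0
t = jnp.arange(0, 10, 1 / fs)
x = jnp.sin(2 * jnp.pi * 5.0 * t)           # 5 Hz tone
freqs, psd = signal.welch(x, fs=fs, nperseg=256)
print(freqs[jnp.argmax(psd)])                # ≈ 5.0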
def conjugate(self):
    """ Returns the conjugate of a tensor. """
    return Tensor(self.dom, self.cod, np.conjugate(self.array))
def J_dag():
    return jnp.transpose(jnp.conjugate(J()))
def dagger(self):
    array = np.moveaxis(self.array, range(len(self.dom @ self.cod)), [
        i + len(self.cod) if i < len(self.dom) else i - len(self.dom)
        for i in range(len(self.dom @ self.cod))])
    return Tensor(self.cod, self.dom, np.conjugate(array))
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #     out_axes = (0, 0)
    # else:
    #     out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #     partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))

    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for standard gradient of energy.
    # this avoids errors in #867, #789, #850
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
def overlap_exact(psi1, psi2):
    '''
    take in two states and return the overlap <psi1 | psi2>.
    psi1 should not be conjugated beforehand; the complex
    conjugation is applied here.
    '''
    return np.dot(np.conjugate(psi1), psi2)
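# A quick check of overlap_exact: <0|+> = 1/sqrt(2).
import numpy as np

psi1 = np.array([1.0 + 0j, 0.0])
psi2 = np.array([1.0, 1.0]) / np.sqrt(2.0)
print(overlap_exact(psi1, psi2))   # (0.7071...+0j)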
def Cc(x: Array) -> Array:
    return np.conjugate(x)
def cost(params):
    state = circuit(params)
    op_qs = jnp.array([[0], [0.707], [0.707], [0]])
    fid = jnp.abs(jnp.dot(jnp.transpose(jnp.conjugate(op_qs)), state)) ** 2
    return -jnp.real(fid)[0][0]
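# A minimal sketch of calling cost with a hypothetical two-qubit `circuit`
# that returns a normalized column state vector; the real circuit is defined
# elsewhere.
import jax.numpy as jnp

def circuit(params):
    # Hypothetical stand-in: interpolate between |00> and the Bell state.
    bell = jnp.array([[0.0], [0.707], [0.707], [0.0]])
    zero = jnp.array([[1.0], [0.0], [0.0], [0.0]])
    s = jnp.cos(params[0]) * zero + jnp.sin(params[0]) * bell
    return s / jnp.linalg.norm(s)

print(cost(jnp.array([jnp.pi / 2])))   # ≈ -1.0 when the circuit reaches the Bell state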