def _vdot_real_part(x, y):
    """Vector dot-product guaranteed to have a real valued result."""
    # all our uses of vdot() in CG are for computing an operator of the form
    # `z^T M z` where `M` is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
    result = vdot(x.real, y.real)
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
        result += vdot(x.imag, y.imag)
    return result
def _vdot_real_part(x, y):
    """Vector dot-product guaranteed to have a real valued result despite
    possibly complex input. Thus neglects the real-imaginary cross-terms.
    The result is a real float.
    """
    # all our uses of vdot() in CG are for computing an operator of the form
    #  z^H M z
    # where M is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    result = _vdot(x.real, y.real)
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
        result += _vdot(x.imag, y.imag)
    return result
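# Added usage sketch (not from the original sources): the real-real plus
# imag-imag sum used above is exactly the real part of the conjugating inner
# product <x, y> = sum(conj(x) * y); a minimal sanity check:
import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
y = jnp.array([0.5 - 1.0j, 2.0 + 0.5j])
lhs = jnp.vdot(x.real, y.real) + jnp.vdot(x.imag, y.imag)
assert jnp.allclose(lhs, jnp.vdot(x, y).real)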
def reim_activation(f, x):
    sqrt2 = jnp.sqrt(jnp.array(2, dtype=x.real.dtype))
    if jnp.iscomplexobj(x):
        return jax.lax.complex(f(sqrt2 * x.real), f(sqrt2 * x.imag)) / sqrt2
    else:
        return f(x)
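# Added usage sketch (not from the original source), assuming `reim_activation`
# from the snippet above and the usual jax imports are in scope: the
# nonlinearity is applied to the rescaled real and imaginary channels
# independently, and reduces to a plain f(x) for real input.
import jax
import jax.numpy as jnp

z = jnp.array([1.0 + 2.0j, -0.5 - 1.0j])
out = reim_activation(jax.nn.relu, z)
assert jnp.iscomplexobj(out)
assert jnp.allclose(reim_activation(jax.nn.relu, z.real), jax.nn.relu(z.real))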
def odefun(state: MCState, driver: TDVP, t, w, *, stage=0):  # noqa: F811
    # pylint: disable=protected-access
    state.parameters = w
    state.reset()

    driver._loss_stats, driver._loss_grad = state.expect_and_grad(
        driver.generator(t),
        use_covariance=True,
    )
    driver._loss_grad = jax.tree_map(
        lambda x: driver._loss_grad_factor * x, driver._loss_grad
    )

    qgt = driver.qgt(driver.state)
    if stage == 0:  # TODO: This does not work with FSAL.
        driver._last_qgt = qgt

    initial_dw = None if driver.linear_solver_restart else driver._dw
    driver._dw, _ = qgt.solve(driver.linear_solver, driver._loss_grad, x0=initial_dw)

    # If parameters are real, then take only real part of the gradient (if it's complex)
    driver._dw = jax.tree_map(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
        driver._dw,
        state.parameters,
    )

    return driver._dw
def _forward_and_backward(self):
    """
    Performs a single VMC optimization step: computes the energy estimator,
    its gradient, and the (preconditioned) parameter update.
    """
    self.state.reset()

    # Compute the local energy estimator and average Energy
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)

    # if it's the identity it does
    # self._dp = self._loss_grad
    self._dp = self.preconditioner(self.state, self._loss_grad)

    # If parameters are real, then take only real part of the gradient (if it's complex)
    self._dp = jax.tree_map(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
        self._dp,
        self.state.parameters,
    )

    return self._dp
def test_optimization(self, opt_name, opt, target, dtype):
    if (opt_name in ('fromage', 'noisy_sgd', 'sm3')
            and jnp.iscomplexobj(dtype)):
        raise absltest.SkipTest(
            f'{opt_name} does not support complex parameters.')

    opt = opt()
    initial_params, final_params, get_updates = target(dtype)

    @jax.jit
    def step(params, state):
        updates = get_updates(params)
        if opt_name == 'dpsgd':
            updates = updates[None]
        # Complex gradients need to be conjugated before being added to parameters
        # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
        updates = jax.tree_map(lambda x: x.conj(), updates)
        updates, state = opt.update(updates, state, params)
        params = update.apply_updates(params, updates)
        return params, state

    params = initial_params
    state = opt.init(params)
    for _ in range(10000):
        params, state = step(params, state)

    chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
def tree_conj(t: PyTree) -> PyTree:
    r"""
    Conjugate all complex leaves. The real leaves are left untouched.

    Args:
        t: pytree
    """
    return jax.tree_map(
        lambda x: jax.lax.conj(x) if jnp.iscomplexobj(x) else x, t
    )
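# Added usage sketch (not from the original source), assuming `tree_conj` from
# the snippet above is in scope: complex leaves are conjugated, real leaves
# pass through unchanged.
import jax.numpy as jnp

tree = {"w": jnp.array([1.0 + 1.0j, 2.0 - 3.0j]), "b": jnp.array([1.0, 2.0])}
out = tree_conj(tree)
assert jnp.allclose(out["w"], jnp.conj(tree["w"]))
assert jnp.allclose(out["b"], tree["b"])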
def _forward_and_backward(self):
    """
    Performs a single optimization step: computes the loss estimator, its
    gradient, and the (possibly SR-preconditioned) parameter update.
    """
    self.state.reset()

    # Compute the local energy estimator and average Energy
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ldag_l)

    if self.sr is not None:
        self._S = self.state.quantum_geometric_tensor(self.sr)
        # use the previous solution as an initial guess to speed up the solution of the linear system
        x0 = self._dp if self.sr_restart is False else None
        self._dp = self._S.solve(self._loss_grad, x0=x0)
    else:
        # tree_map(lambda x, y: x if is_ccomplex(y) else x.real, self._grads, self.state.parameters)
        self._dp = self._loss_grad

    # If parameters are real, then take only real part of the gradient (if it's complex)
    self._dp = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
        self._dp,
        self.state.parameters,
    )

    return self._dp
def _to_im(x):
    if jnp.iscomplexobj(x):
        return x.imag  # TODO find a way to make it a nop?
        # return jax.vmap(lambda y: jnp.array((y.real, y.imag)))(x)
    else:
        return None
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel_chunked(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
        chunk_size=chunk_size,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            σ,
            conjugate=True,
            chunk_size=chunk_size,
            chunk_argnums=1,
            nondiff_argnums=1,
        )
        new_model_state = None
    else:
        raise NotImplementedError

    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),
    )[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def tree_conj(t):
    r"""
    Conjugate all complex leaves. The real leaves are left untouched.

    t: pytree
    """
    return jax.tree_map(
        lambda x: jax.lax.conj(x) if jnp.iscomplexobj(x) else x, t
    )
def check(x, target):
    par_iscomplex = jnp.iscomplexobj(x)

    # Account for split real-imaginary part in Jacobian*** methods
    if isinstance(target, tuple):
        vec_iscomplex = True if len(target) == 2 else False
    else:
        vec_iscomplex = jnp.iscomplexobj(target)

    if not par_iscomplex and vec_iscomplex:
        raise TypeError(
            dedent(
                """
                Cannot multiply the (real part of the) QGT by a complex vector.
                You should either take the real part of the vector, or perform
                the multiplication against the real and imaginary parts of the
                vector separately and then recompose the two.
                """
            )
        )
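# Added usage sketch (not from the original source), assuming `check` from the
# snippet above is in scope together with its `dedent` import from textwrap:
# real parameters combined with a complex vector are rejected.
import jax.numpy as jnp

check(jnp.ones(3) + 0.0j, jnp.ones(3) + 0.5j)   # complex parameters: accepted
try:
    check(jnp.ones(3), jnp.ones(3) + 0.5j)      # real parameters, complex vector
except TypeError:
    pass  # expected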
def grad_expect_operator_kernel(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, Stats]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    is_mutable = mutable is not False
    logpsi = lambda w, σ: model_apply_fun(
        {"params": w, **model_state}, σ, mutable=mutable
    )
    log_pdf = (
        lambda w, σ: machine_pow * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure_pars(pars):
        return nkjax.expect(
            log_pdf,
            partial(local_value_kernel, logpsi),
            pars,
            σ,
            local_value_args,
            n_chains=σ_shape[0],
        )

    Ō, Ō_pb, Ō_stats = nkjax.vjp(
        expect_closure_pars, parameters, has_aux=True, conjugate=True
    )
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]

    # This term below is needed otherwise it does not match the value obtained by
    # (ha@ha).collect(). I'm unsure of why it is needed.
    Ō_pars_grad = jax.tree_multimap(
        lambda x, target: x / 2 if jnp.iscomplexobj(target) else x,
        Ō_pars_grad,
        parameters,
    )

    if is_mutable:
        raise NotImplementedError(
            "gradient of non-hermitian operators over mutable models "
            "is not yet implemented."
        )
    new_model_state = None

    return (
        Ō_stats,
        jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], Ō_pars_grad),
        new_model_state,
    )
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )

    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
def logsumexp_cplx(a, b=None, **kwargs):
    """Compute the log of the sum of exponentials of input elements, always
    returning a complex number.

    Equivalent to, but more numerically stable than,
    `np.log(np.sum(b*np.exp(a)))`. If the optional argument `b` is omitted,
    `np.log(np.sum(np.exp(a)))` is returned.

    Wraps `jax.scipy.special.logsumexp` but uses `return_sign=True` if both
    `a` and `b` are real numbers in order to support `b<0` instead of
    returning `nan`. See the JAX function for details of the calling
    sequence; `return_sign` is not supported.
    """
    if jnp.iscomplexobj(a) or jnp.iscomplexobj(b):
        # logsumexp uses complex algebra anyway
        return logsumexp(a, b=b, **kwargs)
    else:
        a, sgn = logsumexp(a, b=b, **kwargs, return_sign=True)
        a = a + jnp.where(sgn < 0, 1j * jnp.pi, 0j)
        return a
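# Added usage sketch (not from the original source), assuming `logsumexp_cplx`
# from the snippet above is in scope: with a negative weight b the plain
# logsumexp would return nan, while the wrapper folds the sign into an i*pi
# imaginary offset.
import jax.numpy as jnp

a = jnp.array([0.0, 0.0])
b = jnp.array([1.0, -3.0])           # sum(b * exp(a)) = -2 < 0
res = logsumexp_cplx(a, b=b)
assert jnp.iscomplexobj(res)
assert jnp.allclose(jnp.exp(res), -2.0 + 0.0j)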
def anf(signal, f0, sr, A=1, phi=0, lr=1e-4, device=cpus[0]):
    if jnp.iscomplexobj(signal):
        signal = jnp.stack([signal.real, signal.imag], axis=-1)
        signal = vmap(_anf, in_axes=(-1,) + (None,) * 6, out_axes=(-1,))(
            signal, f0, sr, A, phi, lr, device
        )
        signal = signal[..., 0] + jnp.array(1j) * signal[..., 1]
    else:
        signal = _anf(signal, f0, sr, A, phi, lr, device)
    return signal
def test_logsumexp_cplx(a, b):
    a = jnp.asarray(a)
    if b is not None:
        b = jnp.asarray(b)
        expected = jnp.log(complex(jnp.exp(a[0]) * b[0] + jnp.exp(a[1]) * b[1]))
    else:
        expected = jnp.log(complex(jnp.exp(a[0]) + jnp.exp(a[1])))

    c = logsumexp_cplx(a, b=b)
    assert jnp.iscomplexobj(c)
    assert_allclose(c, expected, atol=1e-8)
def _setup_parabola(dtype):
    """Quadratic function as an optimization target."""
    initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
    final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

    if jnp.iscomplexobj(dtype):
        final_params *= 1 + 1j

    @jax.grad
    def get_updates(params):
        return jnp.sum(numerics.abs_sq(params - final_params))

    return initial_params, final_params, get_updates
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def gaussian_random_fill(work_array):
    """
    Fill work_array with random values in place.
    """
    key = jax.random.PRNGKey(int(time.time()))
    subkey1, subkey2 = jax.random.split(key, 2)
    output = jax.lax.cond(
        jnp.iscomplexobj(work_array),
        (subkey1, work_array),
        lambda x: gaussian_random_complex_arr(x).astype(work_array.dtype),
        (subkey2, work_array),
        lambda x: gaussian_random_real_arr(x).astype(work_array.dtype))
    work_array = jax.ops.index_update(work_array, index[:], output)
    return work_array
def tree_cast(x, target):
    r"""
    Cast each leaf of x to the dtype of the corresponding leaf in target.
    The imaginary part of complex leaves which are cast to real is discarded.

    Args:
        x: a pytree with arrays as leaves
        target: a pytree with the same treedef as x, where only the dtypes of
            the leaves are accessed
    """
    # astype alone would also work, however that raises ComplexWarning when
    # casting complex to real, therefore the real part is taken first where needed
    return jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(target.dtype),
        x,
        target,
    )
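# Added usage sketch (not from the original source), assuming `tree_cast` from
# the snippet above is in scope: complex leaves are truncated to their real
# part when the corresponding target leaf is real.
import jax.numpy as jnp

x = {"a": jnp.array([1.0 + 2.0j]), "b": jnp.array([1.0 + 2.0j])}
target = {"a": jnp.zeros(1, dtype=jnp.float32), "b": jnp.zeros(1, dtype=jnp.complex64)}
out = tree_cast(x, target)
assert out["a"].dtype == jnp.float32 and jnp.allclose(out["a"], 1.0)
assert out["b"].dtype == jnp.complex64 and jnp.allclose(out["b"], 1.0 + 2.0j)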
def _setup_rosenbrock(dtype):
    """Rosenbrock function as an optimization target."""
    a = 1.0
    b = 100.0

    if jnp.iscomplexobj(dtype):
        a *= 1 + 1j

    initial_params = jnp.array([0.0, 0.0], dtype=dtype)
    final_params = jnp.array([a, a**2], dtype=dtype)

    @jax.grad
    def get_updates(params):
        return (numerics.abs_sq(a - params[0]) +
                b * numerics.abs_sq(params[1] - params[0]**2))

    return initial_params, final_params, get_updates
def apply_fun(params, inputs, **kwargs):
    # distinguish between real and complex input because
    # jax is not smart enough to pick the right gemm on its own.
    if jnp.iscomplexobj(inputs):
        if use_bias:
            Wr, Wi, br, bi = params
            return jnp.dot(inputs, Wr + 1j * Wi) + br + 1j * bi
        else:
            Wr, Wi = params
            return jnp.dot(inputs, Wr + 1j * Wi)
    else:
        if use_bias:
            Wr, Wi, br, bi = params
            return jnp.dot(inputs, Wr) + 1j * jnp.dot(inputs, Wi) + br + 1j * bi
        else:
            Wr, Wi = params
            return jnp.dot(inputs, Wr) + 1j * jnp.dot(inputs, Wi)
def c2r(c):
    '''
    Unpack a complex-valued signal into a real-valued signal, for example converting

        [[0.+0.j 1.-1.j]
         [2.-2.j 3.-3.j]
         [4.-4.j 5.-5.j]
         [6.-6.j 7.-7.j]]

    to

        [[ 0.  0.  1. -1.]
         [ 2. -2.  3. -3.]
         [ 4. -4.  5. -5.]
         [ 6. -6.  7. -7.]]
    '''
    if jnp.iscomplexobj(c):
        if c.ndim != 2:
            raise ValueError('invalid ndim, expected 2 but got %d' % c.ndim)
        r = jnp.stack([c.real, c.imag], axis=-1).reshape((c.shape[0], -1))
    else:
        r = c
    return r
def r2c(r):
    '''
    Pack a real-valued signal into a complex-valued signal, for example converting

        [[ 0.  0.  1. -1.]
         [ 2. -2.  3. -3.]
         [ 4. -4.  5. -5.]
         [ 6. -6.  7. -7.]]

    to

        [[0.+0.j 1.-1.j]
         [2.-2.j 3.-3.j]
         [4.-4.j 5.-5.j]
         [6.-6.j 7.-7.j]]
    '''
    if not jnp.iscomplexobj(r):
        if r.ndim != 2:
            raise ValueError('invalid ndim, expected 2 but got %d' % r.ndim)
        r = r.reshape((r.shape[0], r.shape[-1] // 2, -1))
        c = r[..., 0] + 1j * r[..., 1]
    else:
        c = r
    return c
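# Added usage sketch (not from the original source), assuming `c2r` and `r2c`
# from the two snippets above are in scope: the pair round-trips a 2-D complex
# signal, doubling the last axis in the real-valued view.
import jax.numpy as jnp

c = jnp.arange(8, dtype=jnp.float32).reshape(4, 2) * (1.0 - 1.0j)
r = c2r(c)
assert r.shape == (4, 4)
assert jnp.allclose(r2c(r), c)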
def _exp_grad(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    OΨ: jnp.ndarray,
    Ψ: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:

    is_mutable = mutable is not False

    expval_O = (Ψ.conj() * OΨ).sum()
    ΔOΨ = (OΨ - expval_O * Ψ.conj()) * Ψ

    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(ΔOΨ)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return (
        None,
        Ō_grad,
        expval_O,
        new_model_state,
    )
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #     out_axes = (0, 0)
    # else:
    #     out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #     partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )
    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))

    (
        Lρ,
        der_loc_vals,
    ) = _der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # _der_local_values_jax._local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # old implementation
    # this is faster, even though I think the one below should be faster
    # (this works, but... yeah. let's keep it here and delete in a while.)
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = jax.tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO
    # NEW IMPLEMENTATION
    # This should be faster, but should benchmark as it seems slower
    # to compute der_logs_ave I can just do a jvp with a ones vector
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    # der_logs_ave = d_logpsi(
    #     jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for the standard gradient of the energy.
    # this avoids errors in #867, #789, #850
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
def abs2(x):
    if jnp.iscomplexobj(x):
        return x.real**2 + x.imag**2
    else:
        return x**2
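# Added usage sketch (not from the original source), assuming `abs2` from the
# snippet above is in scope: it equals |x|**2 for both real and complex input,
# while avoiding the square root taken by jnp.abs.
import jax.numpy as jnp

z = jnp.array([3.0 + 4.0j, 1.0 - 1.0j])
assert jnp.allclose(abs2(z), jnp.abs(z) ** 2)            # [25., 2.]
assert jnp.allclose(abs2(jnp.array([-2.0, 0.5])), jnp.array([4.0, 0.25]))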
def toreal(x):
    if jnp.iscomplexobj(x):
        return jnp.array([x.real, x.imag])
    else:
        return x
def groundstate(
    self,
    chi: int,
    system_size: Optional[Tuple[int, int]] = None,
    initial_state: Optional[str] = 'ps',
    initial_noise: Optional[float] = None,
    contraction_options: Optional[dict] = None,
    optimisation_options: Optional[dict] = None,
):
    """
    Computes the groundstate of the model by minimising a trial state's energy.

    Parameters
    ----------
    chi : int
        The bond-dimension of the (i)PEPS
    system_size : (int, int), optional
        For OBC and PBC: the system size. For INFINITE: no effect.
    initial_state : str or jax.numpy.ndarray or PEPS or IPEPS, optional
        The initial state for the optimisation.
        If a string keyword:
            'ps'    : the product state of the respective phase
                      (z+ for g > gc ~ 3.5, x+ for g < gc);
                      `initial_noise > 0` recommended, since product states
                      might have zero gradient
            'z+'    : the z+ product state, `initial_noise > 0` recommended
            'x+'    : the x+ product state, `initial_noise > 0` recommended
            'ipeps' : (only for finite systems) finds the (iPEPS) groundstate
                      of the same model on the infinite lattice, which can
                      directly be used for PBC, or is cut off at the boundary
                      for OBC
            TODO 'random' keyword
        If an ndarray:
            1D array of the local state; start from the product state of this
            local state. If it is a z or x eigenstate, `initial_noise > 0`
            is recommended.
        If PEPS (only for finite systems): PEPS of the initial state.
        If IPEPS (only for infinite systems): iPEPS of the initial state.
        Default: like 'ps'
    initial_noise : float, optional
        If `initial_noise > 0`, a random deviation is added to the
        initial_state; `initial_noise` is then the relative strength of the
        deviation.
    contraction_options : dict, optional
        Options for the PEPS contraction of the energy expectation value,
        see `TFIM.energy`.
    optimisation_options : dict, optional
        Options for the optimisation; kwargs for `jax_optimise.minimise`.

    Returns
    -------
    gs : PEPS or IPEPS
    gs_energy : float
    """
    # parse dicts
    optimisation_options = parse_options(optimisation_options, OPTIMISATION_DEFAULTS_GS)
    contraction_options = parse_options(contraction_options)

    # parse system size
    if self.bc == INFINITE:
        lx, ly = None, None
    else:
        lx, ly = system_size
        assert lx > 0
        assert ly > 0

    # parse initial guess
    initial_guess = _parse_initial_state(
        initial_state, chi, self.bc, lx, ly, self.g, initial_noise,
        complex_tensors=np.iscomplexobj(optimisation_options['dtype']))

    # define cost_function
    def cost_function(new_tensors):
        new_state = initial_guess.with_different_tensors(new_tensors)
        energy = self.energy(new_state, **contraction_options)
        return np.reshape(energy, ())

    # optimisation
    optimal_tensors, optimal_energy, info = minimise(
        cost_function, initial_guess.get_tensors(), **optimisation_options)
    optimal_state = initial_guess.with_different_tensors(optimal_tensors)

    return optimal_state, optimal_energy