def quantum_geometric_tensor(self, qgt_T: LinearOperator = None) -> LinearOperator:
    r"""Computes an estimate of the quantum geometric tensor G_ij.

    This function returns a linear operator that can be used to apply G_ij to a
    given vector or can be converted to a full matrix.

    Args:
        qgt_T: the optional type of the quantum geometric tensor. If ``None``
            (the default), ``QGTAuto()`` is used, which selects it
            automatically.

    Returns:
        nk.optimizer.LinearOperator: A linear operator representing the quantum
        geometric tensor.
    """
    # Build the default lazily instead of in the signature: a call placed in a
    # default argument is evaluated once, when the function is defined, and
    # the resulting object is then shared by every call.
    if qgt_T is None:
        qgt_T = QGTAuto()

    # The QGT specification is a callable that is applied to the state itself.
    return qgt_T(self)
def __init__(
    self,
    operator: AbstractOperator,
    variational_state: VariationalState,
    integrator: RKIntegratorConfig,
    *,
    t0: float = 0.0,
    propagation_type="real",
    qgt: LinearOperator = None,
    linear_solver=None,
    linear_solver_restart: bool = False,
    error_norm: Union[str, Callable] = "euclidean",
):
    r"""
    Sets up the time evolution driver.

    Args:
        operator: The generator of the dynamics (Hamiltonian for pure states,
            Lindbladian for density operators).
        variational_state: The variational state.
        integrator: Configuration of the algorithm used for solving the ODE.
        t0: Initial time at the start of the time evolution.
        propagation_type: Determines the equation of motion: "real" for the
            real-time Schrödinger equation (SE), "imag" for the imaginary-time SE.
        qgt: The QGT specification. When omitted, ``QGTAuto`` is used with
            ``linear_solver`` as its solver.
        linear_solver: The solver for solving the linear system determining the
            time evolution. When omitted, ``nk.optimizer.solver.svd`` is used.
        linear_solver_restart: If False (default), the last solution of the
            linear system is used as initial value in subsequent steps.
        error_norm: Norm function used to calculate the error with adaptive
            integrators. Can be either "euclidean" for the standard L2 vector
            norm :math:`w^\dagger w`, "maximum" for the maximum norm
            :math:`\max_i |w_i|` or "qgt", in which case the scalar product
            induced by the QGT :math:`S` is used to compute the norm
            :math:`\Vert w \Vert^2_S = w^\dagger S w` as suggested in
            PRL 125, 100503 (2020).
            Additionally, it is possible to pass a custom function with
            signature :code:`norm(x: PyTree) -> float` which maps a PyTree of
            parameters :code:`x` to the corresponding norm.
            Note that norm is used in jax.jit-compiled code.

    Raises:
        ValueError: If ``propagation_type`` is not supported for the given
            state, or if ``error_norm`` is neither a callable nor one of the
            recognised names.
    """
    self._t0 = t0

    # Resolve the defaults first: the automatic QGT is tied to the solver.
    if linear_solver is None:
        linear_solver = nk.optimizer.solver.svd
    if qgt is None:
        qgt = QGTAuto(solver=linear_solver)

    super().__init__(
        variational_state, optimizer=None, minimized_quantity_name="Generator"
    )

    self._generator_repr = repr(operator)
    if isinstance(operator, AbstractOperator):
        # A fixed operator is collected once and wrapped in a function that
        # ignores its single argument.
        collected = operator.collect()

        def constant_generator(_):
            return collected

        self._generator = constant_generator
    else:
        # Anything else is stored untouched.
        # NOTE(review): presumably a callable of time returning the operator — confirm.
        self._generator = operator

    self.propagation_type = propagation_type
    if isinstance(variational_state, VariationalMixedState):
        # assuming Lindblad Dynamics
        # TODO: support density-matrix imaginary time evolution
        if propagation_type != "real":
            raise ValueError(
                "only real-time Lindblad evolution is supported for " "mixed states"
            )
        self._loss_grad_factor = 1.0
    elif propagation_type == "real":
        self._loss_grad_factor = -1.0j
    elif propagation_type == "imag":
        self._loss_grad_factor = -1.0
    else:
        raise ValueError("propagation_type must be one of 'real', 'imag'")

    self.qgt = qgt
    self.linear_solver = linear_solver
    self.linear_solver_restart = linear_solver_restart

    self._dw = None  # type: PyTree
    self._last_qgt = None

    # A user-supplied callable is used as-is; only names need resolving.
    if not isinstance(error_norm, Callable):
        if error_norm == "euclidean":
            error_norm = euclidean_norm
        elif error_norm == "maximum":
            error_norm = maximum_norm
        elif error_norm == "qgt":
            params = self.state.parameters
            real_dtype = nk.jax.dtype_real(nk.jax.tree_dot(params, params))

            # QGT norm is called via host callback since it accesses the driver
            def qgt_error_norm(x):
                return hcb.call(
                    HashablePartial(qgt_norm, self),
                    x,
                    result_shape=jax.ShapeDtypeStruct((), real_dtype),
                )

            error_norm = qgt_error_norm
        else:
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum'."
            )

    self._odefun = HashablePartial(odefun_host_callback, self.state, self)
    self._integrator = integrator(
        self._odefun,
        t0,
        self.state.parameters,
        norm=error_norm,
    )

    self._stop_count = 0