Beispiel #1
0
def test_hashable_partial_merges_with_hashable_partial():
    """A HashablePartial wrapping another HashablePartial exposes the merged args."""

    def target(a, b, c):
        pass

    inner = HashablePartial(target, 1)
    outer = HashablePartial(inner, 2)

    # The positional arguments of both layers are flattened, outermost last.
    assert outer.args == (1, 2)
Beispiel #2
0
def test_hashable_partial_merges_with_partial():
    """HashablePartial absorbs the args/keywords of nested functools.partial objects."""

    def f(a, b, c, d, e, f, g):
        pass

    def build_chain():
        # Two nested functools.partial layers, then a HashablePartial on top.
        level_one = partial(f, 2, d=3)
        level_two = partial(level_one, 4, e=5)
        return HashablePartial(level_two, 6, f=7)

    merged = build_chain()

    assert merged.args == (2, 4, 6)
    assert merged.keywords == {"d": 3, "e": 5, "f": 7}

    # An independently rebuilt chain with identical arguments compares equal.
    rebuilt = build_chain()

    assert merged == rebuilt
Beispiel #3
0
    def error_norm(self, error_norm: Union[str, Callable]):
        """
        Select the norm used by adaptive integrators to measure the step error.

        Args:
            error_norm: Either a callable, which is used unchanged, or one of the
                names "euclidean", "maximum" or "qgt". The "qgt" choice evaluates
                :code:`qgt_norm` bound to this driver through a host callback and
                declares a scalar result whose dtype is the real dtype of the
                parameters' inner product with themselves.

        Raises:
            ValueError: if ``error_norm`` is neither callable nor a known name.

        The chosen function is stored in ``self._error_norm`` and, when
        ``self.integrator`` is not None, also assigned to ``self.integrator.norm``.
        """
        if isinstance(error_norm, Callable):
            selected = error_norm
        elif error_norm == "euclidean":
            selected = euclidean_norm
        elif error_norm == "maximum":
            selected = maximum_norm
        elif error_norm == "qgt":
            params = self.state.parameters
            out_dtype = nk.jax.dtype_real(nk.jax.tree_dot(params, params))

            # The QGT norm needs access to the driver, so it is evaluated on the
            # host through a callback rather than traced.
            # TODO: make this also an hashablepartial on self to reduce recompilation
            def _qgt_error_norm(x):
                return hcb.call(
                    HashablePartial(qgt_norm, self),
                    x,
                    result_shape=jax.ShapeDtypeStruct((), out_dtype),
                )

            selected = _qgt_error_norm
        else:
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum',"
                f" but {error_norm} was passed."
            )

        self._error_norm = selected

        # Keep an already-constructed integrator in sync with the new norm.
        if self.integrator is not None:
            self.integrator.norm = self._error_norm
Beispiel #4
0
def reim(f):
    r"""Modifies a non-linearity to act separately on the real and imaginary parts.

    Returns a ``HashablePartial`` binding ``f``. For complex input ``x`` the
    resulting function returns ``complex(f(√2·Re x), f(√2·Im x)) / √2``; for
    non-complex input it returns ``f(x)``.
    """

    def _reim_apply(fn, x):
        # Built before the complex check, exactly as in the original flow.
        root_two = jnp.sqrt(jnp.array(2, dtype=x.real.dtype))
        if not jnp.iscomplexobj(x):
            return fn(x)
        real_out = fn(root_two * x.real)
        imag_out = fn(root_two * x.imag)
        return jax.lax.complex(real_out, imag_out) / root_two

    return HashablePartial(_reim_apply, f)
Beispiel #5
0
    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Returns a closure computing the log-pdf sampled by this sampler.

        ``model`` is first converted with ``get_afun_if_module``. The returned
        callable maps ``(pars, σ)`` to ``self.machine_pow * apply_fun(pars, σ).real``.
        ``machine_pow`` is read from the sampler each time the callable runs.

        The closure is wrapped in a HashablePartial, which is intended to keep
        it from triggering recompilation.

        Args:
            model: The machine, or apply_fun

        Returns:
            the log probability density function
        """
        forward = get_afun_if_module(model)
        return HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).real,
            forward,
        )
Beispiel #6
0
    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Wraps the model's forward pass into this sampler's log-pdf.

        Args:
            model: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
                It is normalized with :code:`get_afun_if_module` before use.

        Returns:
            A callable :code:`log_pdf(pars, σ)` evaluating
            :code:`self.machine_pow * apply_fun(pars, σ).real`. ``machine_pow`` is
            looked up on the sampler every time the callable is invoked.

        Note:
            The result is returned as a `HashablePartial`, intended so that the
            closure does not trigger recompilation.
        """
        forward = get_afun_if_module(model)
        scaled_log_density = HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).real,
            forward,
        )
        return scaled_log_density
Beispiel #7
0
    def __init__(
        self,
        operator: AbstractOperator,
        variational_state: VariationalState,
        integrator: RKIntegratorConfig,
        *,
        t0: float = 0.0,
        propagation_type="real",
        qgt: LinearOperator = None,
        linear_solver=None,
        linear_solver_restart: bool = False,
        error_norm: Union[str, Callable] = "euclidean",
    ):
        r"""
        Initializes the time evolution driver.

        Args:
            operator: The generator of the dynamics (Hamiltonian for pure states,
                Lindbladian for density operators). An :code:`AbstractOperator` is
                collected once and wrapped in a generator that ignores its argument;
                any other object is stored unchanged as the generator.
            variational_state: The variational state.
            integrator: Configuration of the algorithm used for solving the ODE.
                It is called with the ODE function, :code:`t0`, the current
                parameters and the resolved :code:`norm` to build the integrator.
            t0: Initial time at the start of the time evolution.
            propagation_type: Determines the equation of motion: "real" for the
                real-time Schrödinger equation (SE), "imag" for the imaginary-time SE.
                For mixed states only "real" (Lindblad dynamics) is supported.
            qgt: The QGT specification. Defaults to :code:`QGTAuto(solver=linear_solver)`.
            linear_solver: The solver for solving the linear system determining the time evolution.
                Defaults to :code:`nk.optimizer.solver.svd`.
            linear_solver_restart: If False (default), the last solution of the linear system
                is used as initial value in subsequent steps.
            error_norm: Norm function used to calculate the error with adaptive integrators.
                Can be either "euclidean" for the standard L2 vector norm :math:`w^\dagger w`,
                "maximum" for the maximum norm :math:`\max_i |w_i|`
                or "qgt", in which case the scalar product induced by the QGT :math:`S` is used
                to compute the norm :math:`\Vert w \Vert^2_S = w^\dagger S w` as suggested
                in PRL 125, 100503 (2020).
                Additionally, it is possible to pass a custom function with signature
                :code:`norm(x: PyTree) -> float`
                which maps a PyTree of parameters :code:`x` to the corresponding norm.
                Note that norm is used in jax.jit-compiled code.

        Raises:
            ValueError: If :code:`propagation_type` is not supported for the kind of
                variational state, or if :code:`error_norm` is neither callable nor
                one of "euclidean", "maximum", "qgt".
        """
        self._t0 = t0

        # Fill in defaults; the default QGT is built around the chosen linear solver.
        if linear_solver is None:
            linear_solver = nk.optimizer.solver.svd
        if qgt is None:
            qgt = QGTAuto(solver=linear_solver)

        # The base driver is constructed without an optimizer.
        super().__init__(variational_state,
                         optimizer=None,
                         minimized_quantity_name="Generator")

        # Operators are collected once and wrapped in a lambda that ignores its
        # argument, so the generator is constant; other inputs are used as given.
        self._generator_repr = repr(operator)
        if isinstance(operator, AbstractOperator):
            op = operator.collect()
            self._generator = lambda _: op
        else:
            self._generator = operator

        # Prefactor selecting the equation of motion:
        # 1.0 for Lindblad dynamics, -1j for real-time SE, -1.0 for imaginary-time SE.
        self.propagation_type = propagation_type
        if isinstance(variational_state, VariationalMixedState):
            # assuming Lindblad Dynamics
            # TODO: support density-matrix imaginary time evolution
            if propagation_type == "real":
                self._loss_grad_factor = 1.0
            else:
                raise ValueError(
                    "only real-time Lindblad evolution is supported for "
                    "mixed states")
        else:
            if propagation_type == "real":
                self._loss_grad_factor = -1.0j
            elif propagation_type == "imag":
                self._loss_grad_factor = -1.0
            else:
                raise ValueError(
                    "propagation_type must be one of 'real', 'imag'")

        self.qgt = qgt
        self.linear_solver = linear_solver
        self.linear_solver_restart = linear_solver_restart

        # Initialised empty here; presumably populated during time steps
        # (last update / last QGT) — TODO confirm against the rest of the class.
        self._dw = None  # type: PyTree
        self._last_qgt = None

        # Resolve error_norm: callables are used as-is, names select a built-in norm.
        if isinstance(error_norm, Callable):
            pass
        elif error_norm == "euclidean":
            error_norm = euclidean_norm
        elif error_norm == "maximum":
            error_norm = maximum_norm
        elif error_norm == "qgt":
            # The scalar result dtype is the real dtype of <w, w> on the current parameters.
            w = self.state.parameters
            norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
            # QGT norm is called via host callback since it accesses the driver
            error_norm = lambda x: hcb.call(
                HashablePartial(qgt_norm, self),
                x,
                result_shape=jax.ShapeDtypeStruct((), norm_dtype),
            )
        else:
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum'."
            )

        # ODE right-hand side with the state and the driver bound as leading arguments.
        self._odefun = HashablePartial(odefun_host_callback, self.state, self)
        self._integrator = integrator(
            self._odefun,
            t0,
            self.state.parameters,
            norm=error_norm,
        )

        self._stop_count = 0