Ejemplo n.º 1
0
 def solve_subproblem(
     self,
     A: SparseCooMatrix,
     ATb: hints.Array,
     lambd: hints.Scalar,
     iteration: hints.Scalar,  # Unused
 ) -> jnp.ndarray:
     """Solve the linear subproblem on the host through ``self._solve``.

     The solver performs a sparse Cholesky factorization that JAX cannot trace,
     so it is dispatched with ``hcb.call``. Outside of tracing this is equivalent
     to ``self._solve(_LinearSolverArgs(A, ATb, lambd))``.

     Args:
         A: Sparse system matrix.
         ATb: Right-hand side (presumably ``A^T b`` given the name); also used
             as the shape/dtype template of the result.
         lambd: Scalar forwarded to the solver inside ``_LinearSolverArgs``.
         iteration: Unused.

     Returns:
         The output of ``self._solve``, declared with the shape and dtype of ``ATb``.
     """
     solver_args = _LinearSolverArgs(A, ATb, lambd)
     # The result is declared to have exactly ATb's shape and dtype, so ATb
     # itself serves as the result_shape template for the host callback.
     return hcb.call(self._solve, solver_args, result_shape=ATb)
Ejemplo n.º 2
0
Archivo: spin.py Proyecto: vlpap/netket
def random_state_batch_spin_impl(hilb: Spin, key, batches, dtype):
    """Draw ``batches`` random configurations from the spin Hilbert space ``hilb``.

    In the unconstrained and spin-1/2 constrained paths, local values are the
    integers ``-2S, -2S + 2, ..., 2S`` (``+-1`` for S = 1/2).

    Args:
        hilb: Spin Hilbert space to sample from.
        key: JAX PRNG key.
        batches: Number of configurations to draw.
        dtype: dtype of the returned array.

    Returns:
        An array of shape ``(batches, hilb.size)`` and dtype ``dtype``.

    Raises:
        ValueError: If, for S = 1/2, ``hilb._total_sz`` cannot be realized with
            ``hilb.size`` spins.
    """
    S = hilb._s
    N = hilb.size
    shape = (batches, N)
    # Number of local states per site: 2S + 1.
    n_states = int(2 * S) + 1

    # If unconstrained space, use fast sampling: draw an index in [0, n_states)
    # per site and map it onto -2S, -2S + 2, ..., 2S.
    if hilb._total_sz is None:
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=n_states)
        two = jnp.asarray(2, dtype=dtype)
        return jnp.asarray(rs * two - (n_states - 1), dtype=dtype)

    # If constrained and S == 1/2, use a trick to sample quickly: build rows with
    # the right number of up (+1) and down (-1) spins, then shuffle each row.
    if n_states == 2:
        # total_sz may be stored as a float (e.g. 0.5 or 1.0); array shapes must
        # be Python ints, so round 2 * total_sz to the nearest integer first.
        m = int(round(2 * hilb._total_sz))
        if abs(m) > N or (N + m) % 2 != 0:
            # Otherwise nup + ndown != N and the result would silently have the
            # wrong number of sites.
            raise ValueError(
                f"total_sz={hilb._total_sz} is incompatible with {N} spin-1/2 sites."
            )
        nup = (N + m) // 2
        ndown = (N - m) // 2

        x = jnp.concatenate(
            (
                jnp.ones((batches, nup), dtype=dtype),
                -jnp.ones((batches, ndown), dtype=dtype),
            ),
            axis=1,
        )

        # Permute every row independently with its own key
        # (replaces the deprecated jax.random.shuffle(key, x, axis=1)).
        return jax.vmap(jax.random.permutation)(
            jax.random.split(key, x.shape[0]), x
        )

    # If constrained and S != 1/2, use a slow host-side fallback algorithm.
    # TODO: find a better, faster way to sample constrained arbitrary spaces.
    from jax.experimental import host_callback as hcb

    def cb(rng):
        return _random_states_with_constraint(hilb, rng, batches, dtype)

    state = hcb.call(
        cb,
        key,
        result_shape=jax.ShapeDtypeStruct(shape, dtype),
    )

    return state
Ejemplo n.º 3
0
def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
    """Call a TensorFlow function from JAX, without autodiff support.

    The TF invocation has to happen outside of the JAX staged computation, so
    it is dispatched to the host with ``hcb.call``.

    Args:
        tf_fun: TensorFlow function applied to ``arg`` on the host.
        arg: Operand (PyTree) passed to ``tf_fun``.
        result_shape: Declared shape/dtype structure of ``tf_fun``'s output.

    Returns:
        The output of ``tf_fun(arg)`` with every ``tf.Tensor`` leaf converted
        to a NumPy array.
    """

    def _leaf_to_numpy(leaf):
        # Expose Tensors as NumPy arrays through the buffer protocol (intended
        # to avoid a copy); any other leaf is returned unchanged.
        if isinstance(leaf, tf.Tensor):
            return np.asarray(memoryview(leaf))
        return leaf

    def _host_fn(host_arg):
        outputs = tf_fun(host_arg)
        return tf.nest.map_structure(_leaf_to_numpy, outputs)

    return hcb.call(_host_fn, arg, result_shape=result_shape)
Ejemplo n.º 4
0
        def wrapped_exec_bwd(params, g):
            """Backward rule: scale the host-computed tape Jacobian by ``g``.

            The Jacobian is declared with shape ``(1, len(params))``, i.e. a single
            output row.

            Returns:
                A one-element tuple holding the list of per-parameter gradients.
            """

            def _host_jacobian(host_params):
                # Set the parameters on a copy of the tape rather than on ``self``.
                tape_copy = self.copy()
                tape_copy.set_parameters(host_params)
                return tape_copy.jacobian(
                    device, params=host_params, **tape_copy.jacobian_options
                )

            flat_g = g.reshape((-1, ))
            jac_spec = jax.ShapeDtypeStruct((1, len(params)), JAXInterface.dtype)
            jac = host_callback.call(_host_jacobian, params, result_shape=jac_spec)
            scaled = flat_g * jac
            # The trailing comma makes this a one-element tuple, as required.
            return (list(scaled.reshape((-1, ))), )
Ejemplo n.º 5
0
def odefun_host_callback(state, driver, *args, **kwargs):
    """Evaluate ``odefun`` through a host callback so the ODE solver stays jit-able.

    Args:
        state: The variational state; the structure, shapes and dtypes of
            ``state.parameters`` define the declared output of the callback.
        driver: The driver, forwarded to ``odefun``.
        *args: Extra positional arguments forwarded to ``odefun``.
        **kwargs: Extra keyword arguments forwarded to ``odefun``.

    Returns:
        ``odefun(state, driver, *args, **kwargs)``, which must match the PyTree
        structure, shapes and dtypes of ``state.parameters``.
    """
    # jax.tree_util.tree_map is the canonical spelling; the top-level
    # jax.tree_map alias is deprecated.
    result_shape = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        state.parameters,
    )

    def _call_odefun(args_and_kwargs):
        call_args, call_kwargs = args_and_kwargs
        return odefun(state, driver, *call_args, **call_kwargs)

    # Pack args and kwargs together, since host_callback passes a single operand.
    return hcb.call(
        _call_odefun,
        (args, kwargs),
        result_shape=result_shape,
    )
Ejemplo n.º 6
0
def random_state(hilb: Fock, key, batches: int, *, dtype=np.float32):
    """Draw ``batches`` random occupation configurations from the Fock space ``hilb``.

    Args:
        hilb: Fock Hilbert space to sample from.
        key: JAX PRNG key.
        batches: Number of configurations to draw.
        dtype: dtype of the returned array (default ``np.float32``).

    Returns:
        An array of shape ``(batches, hilb.size)`` with dtype ``dtype``.
    """
    out_shape = (batches, hilb.size)

    if hilb.n_particles is not None:
        # Constrained space (fixed n_particles): fall back to a host-side
        # sampler invoked through a host callback.
        from jax.experimental import host_callback as hcb

        def _constrained_sample(rng):
            return _random_states_with_constraint(hilb, rng, batches, dtype)

        return hcb.call(
            _constrained_sample,
            key,
            result_shape=jax.ShapeDtypeStruct(out_shape, dtype),
        )

    # Unconstrained space: every site occupation is uniform in [0, n_max].
    occupations = jax.random.randint(
        key, shape=out_shape, minval=0, maxval=hilb.n_max + 1
    )
    return jnp.asarray(occupations, dtype=dtype)
Ejemplo n.º 7
0
    def wrapped_exec(params):
        """Forward pass: execute the tapes with ``params`` through a host callback.

        Args:
            params: Parameter values zipped against ``tapes`` (presumably one
                entry per tape).

        Returns:
            The host callback output, declared as ``total_size`` arrays of shape
            ``(1,)`` and dtype ``dtype``.
        """

        def _bind(tape, tape_params):
            # Copy the tape (including its operations) and load new parameters.
            bound = tape.copy(copy_operations=True)
            bound.set_parameters(tape_params)
            return bound

        def wrapper(p):
            """Compute the forward pass."""
            new_tapes = [_bind(t, a) for t, a in zip(tapes, p)]

            with qml.tape.Unwrap(*new_tapes):
                results, _ = execute_fn(new_tapes, **gradient_kwargs)

            return results

        result_shapes = [
            jax.ShapeDtypeStruct((1, ), dtype) for _ in range(total_size)
        ]
        return host_callback.call(wrapper, params, result_shape=result_shapes)
Ejemplo n.º 8
0
    def _sample_chain(
        sampler,
        machine: nn.Module,
        parameters: PyTree,
        state: SamplerState,
        chain_length: int,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """Draw a whole 'chain' of exact samples in a single pass.

        ``sample_chain`` is reimplemented because the samples are independent
        (it is not really a chain), so everything can be drawn at once and Python
        is called into only once. ``machine`` and ``parameters`` are unused.

        Returns:
            ``(samples, new_state)``: samples of shape
            ``(chain_length, sampler.n_chains_per_rank, sampler.hilbert.size)``
            with dtype ``sampler.dtype``, and the state with a fresh RNG key.
        """
        n_per_rank = sampler.n_chains_per_rank
        n_total = chain_length * n_per_rank
        n_sites = sampler.hilbert.size

        next_rng, draw_rng = jax.random.split(state.rng)
        # Basis-state indices drawn i.i.d. with probabilities state.pdf.
        indices = jax.random.choice(
            draw_rng,
            sampler.hilbert.n_states,
            shape=(n_total,),
            replace=True,
            p=state.pdf,
        )

        # The index -> configuration conversion is written with numba and has not
        # been ported to JAX yet, so it runs through a host callback.
        #
        # For future investigators: this crashes if numbers_to_states throws,
        # which it does when fed NaNs.
        def _indices_to_states(idx):
            return sampler.hilbert.numbers_to_states(idx)

        flat_samples = hcb.call(
            _indices_to_states,
            indices,
            result_shape=jax.ShapeDtypeStruct((n_total, n_sites), jnp.float64),
        )
        samples = jnp.asarray(flat_samples, dtype=sampler.dtype)
        samples = samples.reshape(chain_length, n_per_rank, n_sites)

        return samples, state.replace(rng=next_rng)
Ejemplo n.º 9
0
def _sample_chain(
    sampler: ExactSampler,
    machine: nn.Module,
    parameters: PyTree,
    state: SamplerState,
    chain_length: int,
) -> Tuple[jnp.ndarray, SamplerState]:
    """
    Internal method used for jitting calls.

    Draws ``chain_length * sampler.n_chains_per_rank`` independent basis states
    according to ``state.pdf`` and converts them to configurations on the host.
    ``machine`` and ``parameters`` are unused.

    Returns:
        ``(samples, new_state)``: samples of shape
        ``(chain_length, sampler.n_chains_per_rank, sampler.hilbert.size)`` with
        dtype ``sampler.dtype``, and the state with a fresh RNG key.
    """
    chains = sampler.n_chains_per_rank
    sites = sampler.hilbert.size
    flat_count = chain_length * chains

    rng_after, rng_now = jax.random.split(state.rng)
    drawn = jax.random.choice(
        rng_now,
        sampler.hilbert.n_states,
        shape=(flat_count, ),
        replace=True,
        p=state.pdf,
    )

    # The integer -> state conversion is written with numba and has not been
    # converted to JAX yet, hence the host callback.
    #
    # For future investigators: this crashes if numbers_to_states throws,
    # which it does when fed NaNs.
    def _to_states(labels):
        return sampler.hilbert.numbers_to_states(labels)

    host_result = hcb.call(
        _to_states,
        drawn,
        result_shape=jax.ShapeDtypeStruct((flat_count, sites), jnp.float64),
    )
    out = jnp.asarray(host_result, dtype=sampler.dtype).reshape(
        chain_length, chains, sites)

    return out, state.replace(rng=rng_after)
Ejemplo n.º 10
0
    def _sample_next(sampler, machine, parameters, state):
        """Draw one exact sample per chain.

        ``machine`` and ``parameters`` are unused.

        Returns:
            ``(new_state, samples)``: the state with a fresh RNG key and samples of
            shape ``(sampler.n_chains_per_rank, sampler.hilbert.size)`` with dtype
            ``sampler.dtype``.
        """
        n_chains = sampler.n_chains_per_rank
        next_rng, draw_rng = jax.random.split(state.rng)
        labels = jax.random.choice(
            draw_rng,
            sampler.hilbert.n_states,
            shape=(n_chains, ),
            replace=True,
            p=state.pdf,
        )

        # The label -> state conversion is written with numba and has not been
        # converted to JAX yet, so it runs through a host callback.
        def _labels_to_states(idx):
            return sampler.hilbert.numbers_to_states(idx)

        out_spec = jax.ShapeDtypeStruct(
            (n_chains, sampler.hilbert.size), jnp.float64)
        host_states = hcb.call(_labels_to_states, labels, result_shape=out_spec)

        new_state = state.replace(rng=next_rng)
        return new_state, jnp.asarray(host_states, dtype=sampler.dtype)
Ejemplo n.º 11
0
    def error_norm(self, error_norm: Union[str, Callable]):
        """Set the norm used for error estimation in adaptive integrators.

        Args:
            error_norm: A callable mapping a parameter PyTree to its norm, or one
                of ``"euclidean"``, ``"maximum"`` or ``"qgt"``.

        Raises:
            ValueError: If ``error_norm`` is neither callable nor a known name.

        If an integrator already exists, its ``norm`` is updated as well.
        """
        if isinstance(error_norm, Callable):
            norm_fn = error_norm
        elif error_norm == "euclidean":
            norm_fn = euclidean_norm
        elif error_norm == "maximum":
            norm_fn = maximum_norm
        elif error_norm == "qgt":
            params = self.state.parameters
            norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(params, params))

            # The QGT norm accesses the driver, so it runs via a host callback.
            # TODO: make this also an hashablepartial on self to reduce recompilation
            def _qgt_host_norm(x):
                return hcb.call(
                    HashablePartial(qgt_norm, self),
                    x,
                    result_shape=jax.ShapeDtypeStruct((), norm_dtype),
                )

            norm_fn = _qgt_host_norm
        else:
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum',"
                f" but {error_norm} was passed.")

        self._error_norm = norm_fn
        if self.integrator is not None:
            self.integrator.norm = self._error_norm
Ejemplo n.º 12
0
 def wrapped_exec(params):
     """Run ``self.execute_device`` on ``device`` from inside JAX via a host callback.

     Returns:
         The host result, declared as a length-1 array of dtype
         ``JAXInterface.dtype``.
     """
     run_on_device = partial(self.execute_device, device=device)
     out_spec = jax.ShapeDtypeStruct((1, ), JAXInterface.dtype)
     return host_callback.call(run_on_device, params, result_shape=out_spec)
Ejemplo n.º 13
0
    def __init__(
        self,
        operator: AbstractOperator,
        variational_state: VariationalState,
        integrator: RKIntegratorConfig,
        *,
        t0: float = 0.0,
        propagation_type="real",
        qgt: LinearOperator = None,
        linear_solver=None,
        linear_solver_restart: bool = False,
        error_norm: Union[str, Callable] = "euclidean",
    ):
        r"""
        Initializes the time evolution driver.

        Args:
            operator: The generator of the dynamics (Hamiltonian for pure states,
                Lindbladian for density operators). An :code:`AbstractOperator` is
                collected once and wrapped in a one-argument function that ignores
                its argument; any other object is stored as the generator as-is
                (presumably a one-argument callable — TODO confirm).
            variational_state: The variational state.
            integrator: Configuration of the algorithm used for solving the ODE.
                It is called with the ODE right-hand side, :code:`t0`, the initial
                parameters and the error norm to build the integrator.
            t0: Initial time at the start of the time evolution.
            propagation_type: Determines the equation of motion: "real" for the
                real-time Schrödinger equation (SE), "imag" for the imaginary-time SE.
                Mixed states only support "real" (Lindblad dynamics).
            qgt: The QGT specification. Defaults to
                :code:`QGTAuto(solver=linear_solver)`.
            linear_solver: The solver for solving the linear system determining the
                time evolution. Defaults to :code:`nk.optimizer.solver.svd`.
            linear_solver_restart: If False (default), the last solution of the linear system
                is used as initial value in subsequent steps.
            error_norm: Norm function used to calculate the error with adaptive integrators.
                Can be either "euclidean" for the standard L2 vector norm :math:`w^\dagger w`,
                "maximum" for the maximum norm :math:`\max_i |w_i|`
                or "qgt", in which case the scalar product induced by the QGT :math:`S` is used
                to compute the norm :math:`\Vert w \Vert^2_S = w^\dagger S w` as suggested
                in PRL 125, 100503 (2020).
                Additionally, it is possible to pass a custom function with signature
                :code:`norm(x: PyTree) -> float`
                which maps a PyTree of parameters :code:`x` to the corresponding norm.
                Note that norm is used in jax.jit-compiled code.

        Raises:
            ValueError: If :code:`propagation_type` is not supported for the given
                state, or if :code:`error_norm` is neither callable nor one of the
                supported names.
        """
        self._t0 = t0

        if linear_solver is None:
            linear_solver = nk.optimizer.solver.svd
        if qgt is None:
            qgt = QGTAuto(solver=linear_solver)

        super().__init__(variational_state,
                         optimizer=None,
                         minimized_quantity_name="Generator")

        self._generator_repr = repr(operator)
        if isinstance(operator, AbstractOperator):
            # Collect once; the resulting generator ignores its argument.
            op = operator.collect()
            self._generator = lambda _: op
        else:
            self._generator = operator

        self.propagation_type = propagation_type
        if isinstance(variational_state, VariationalMixedState):
            # assuming Lindblad Dynamics
            # TODO: support density-matrix imaginary time evolution
            if propagation_type == "real":
                self._loss_grad_factor = 1.0
            else:
                raise ValueError(
                    "only real-time Lindblad evolution is supported for "
                    "mixed states")
        else:
            if propagation_type == "real":
                self._loss_grad_factor = -1.0j
            elif propagation_type == "imag":
                self._loss_grad_factor = -1.0
            else:
                raise ValueError(
                    "propagation_type must be one of 'real', 'imag'")

        self.qgt = qgt
        self.linear_solver = linear_solver
        self.linear_solver_restart = linear_solver_restart

        self._dw = None  # type: PyTree
        self._last_qgt = None

        if isinstance(error_norm, Callable):
            pass
        elif error_norm == "euclidean":
            error_norm = euclidean_norm
        elif error_norm == "maximum":
            error_norm = maximum_norm
        elif error_norm == "qgt":
            w = self.state.parameters
            norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
            # QGT norm is called via host callback since it accesses the driver
            error_norm = lambda x: hcb.call(
                HashablePartial(qgt_norm, self),
                x,
                result_shape=jax.ShapeDtypeStruct((), norm_dtype),
            )
        else:
            # Same message as the error_norm setter, including the bad value.
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum',"
                f" but {error_norm} was passed."
            )

        self._odefun = HashablePartial(odefun_host_callback, self.state, self)
        self._integrator = integrator(
            self._odefun,
            t0,
            self.state.parameters,
            norm=error_norm,
        )

        self._stop_count = 0
Ejemplo n.º 14
0
    def wrapped_exec_bwd(params, g):
        """Backward pass: compute the VJP of the tape executions w.r.t. ``params``.

        With a ``qml.gradients.gradient_transform`` the VJP is computed on the
        host and split per tape. Otherwise the gradient function is treated as a
        device method returning Jacobians.

        Args:
            params: Parameter arrays, zipped against ``tapes``.
            g: Cotangent(s) of the forward outputs.

        Returns:
            A one-element tuple wrapping a tuple of per-tape VJPs.
        """
        if not isinstance(gradient_fn, qml.gradients.gradient_transform):
            # Gradient function is a device method.
            with qml.tape.Unwrap(*tapes):
                jacs = gradient_fn(tapes, **gradient_kwargs)

            vjps = [qml.gradients.compute_vjp(d, jac) for d, jac in zip(g, jacs)]
            res = [[jnp.array(p) for p in v] for v in vjps]
            return (tuple(res), )

        def _rebind(tape, tape_params):
            # Copy the tape with new parameters, keeping the original's
            # trainable parameter indices.
            tape_copy = tape.copy(copy_operations=True)
            tape_copy.set_parameters(tape_params)
            tape_copy.trainable_params = tape.trainable_params
            return tape_copy

        def non_diff_wrapper(args):
            """Compute the VJP in a non-differentiable manner."""
            *tape_params, dy = args
            new_tapes = [_rebind(t, a) for t, a in zip(tapes, tape_params)]

            vjp_tapes, processing_fn = qml.gradients.batch_vjp(
                new_tapes,
                dy,
                gradient_fn,
                reduction="append",
                gradient_kwargs=gradient_kwargs,
            )

            partial_res = execute_fn(vjp_tapes)[0]
            return np.concatenate(processing_fn(partial_res))

        flat_vjps = host_callback.call(
            non_diff_wrapper,
            tuple(params) + (g, ),
            result_shape=jax.ShapeDtypeStruct((total_params, ), dtype),
        )

        # Split the flat VJP vector into one chunk per tape.
        grouped = []
        start = 0
        for tape_params in params:
            stop = start + len(tape_params)
            grouped.append(flat_vjps[start:stop])
            start = stop

        # Unwrap partial results into ndim=0 arrays to allow
        # differentiability with JAX
        # E.g.,
        # [DeviceArray([-0.9553365], dtype=float32), DeviceArray([0., 0.],
        # dtype=float32)]
        # is mapped to
        # [[DeviceArray(-0.9553365, dtype=float32)], [DeviceArray(0.,
        # dtype=float32), DeviceArray(0., dtype=float32)]].
        # Chunks that are already scalars (ndim == 0) are kept as they are.
        res = [
            [jnp.array(x) for x in chunk] if chunk.ndim != 0 else chunk
            for chunk in grouped
        ]

        return (tuple(res), )