def solve_subproblem(
    self,
    A: SparseCooMatrix,
    ATb: hints.Array,
    lambd: hints.Scalar,
    iteration: hints.Scalar,  # Unused
) -> jnp.ndarray:
    """Solve the linear subproblem on the host via a JAX host callback.

    The solver arguments are bundled into a ``_LinearSolverArgs`` and handed
    to ``self._solve`` outside of the staged JAX computation (per the original
    note: a JAX-compatible sparse Cholesky factorization).

    Args:
        A: Sparse matrix of the linear system.
        ATb: Right-hand side; also used as the shape/dtype template for the
            callback result.
        lambd: Scalar forwarded to the solver.
        iteration: Accepted for interface compatibility; not used here.

    Returns:
        The callback result, declared to have the shape and dtype of ``ATb``.
    """
    # Roughly equivalent to calling `self._solve(solver_args)` directly,
    # except that the call is routed through the host.
    solver_args = _LinearSolverArgs(A, ATb, lambd)
    return hcb.call(self._solve, solver_args, result_shape=ATb)
def random_state_batch_spin_impl(hilb: Spin, key, batches, dtype):
    """Draw ``batches`` random configurations of the spin space ``hilb``.

    Three strategies are used, depending on the space:

    * no total-magnetisation constraint: sample every site independently;
    * constrained with two local states: build a row with the right number
      of ``+1`` / ``-1`` entries and permute each row independently;
    * constrained otherwise: defer to ``_random_states_with_constraint`` on
      the host through a host callback.

    Args:
        hilb: The spin Hilbert space (reads ``_s``, ``_total_sz``, ``size``).
        key: JAX PRNG key.
        batches: Number of configurations to generate.
        dtype: dtype of the returned array.

    Returns:
        An array of shape ``(batches, hilb.size)``.
    """
    S = hilb._s
    shape = (batches, hilb.size)

    # Unconstrained space: fast independent sampling of every site.
    if hilb._total_sz is None:
        n_states = int(2 * S + 1)
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=n_states)
        two = jnp.asarray(2, dtype=dtype)
        # Map the integer labels 0..n_states-1 onto the values
        # -(n_states - 1), ..., n_states - 1 in steps of two.
        return jnp.asarray(rs * two - (n_states - 1), dtype=dtype)

    N = hilb.size
    n_states = int(2 * S) + 1

    # Constrained space with two local states: every valid configuration is
    # a permutation of a fixed number of up and down entries.
    if n_states == 2:
        m = hilb._total_sz * 2
        nup = (N + m) // 2
        ndown = (N - m) // 2

        ups = jnp.ones((batches, nup), dtype=dtype)
        downs = -jnp.ones((batches, ndown), dtype=dtype)
        x = jnp.concatenate((ups, downs), axis=1)

        # jax.random.shuffle(key, x, axis=1) is deprecated, so every row is
        # permuted with its own key instead.
        row_keys = jax.random.split(key, x.shape[0])
        return jax.vmap(jax.random.permutation)(row_keys, x)

    # Constrained space with more than two local states: slow fallback.
    # TODO: find a better, faster way to sample constrained arbitrary spaces.
    from jax.experimental import host_callback as hcb

    def cb(rng):
        return _random_states_with_constraint(hilb, rng, batches, dtype)

    state = hcb.call(
        cb,
        key,
        result_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    return state
def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
    """Call a TensorFlow function from JAX, without AD support.

    This is the simplest way of calling into TF. ``hcb.call`` is required
    because the TF invocation must happen outside the JAX staged computation.

    Args:
        tf_fun: The TensorFlow callable, invoked on the host with ``arg``.
        arg: The argument passed through the host callback to ``tf_fun``.
        result_shape: Shape/dtype description of the result, forwarded to
            ``hcb.call``.

    Returns:
        The result of ``hcb.call``.
    """

    def tf_to_numpy(t):
        # Tensors are exposed as NumPy arrays through a memoryview (intended
        # to avoid a copy); anything else is passed through untouched.
        if isinstance(t, tf.Tensor):
            return np.asarray(memoryview(t))
        return t

    def run_on_host(host_arg):
        tf_result = tf_fun(host_arg)
        return tf.nest.map_structure(tf_to_numpy, tf_result)

    return hcb.call(run_on_host, arg, result_shape=result_shape)
def wrapped_exec_bwd(params, g):
    """Backward pass: scale the tape Jacobian by the incoming cotangent ``g``.

    The Jacobian is evaluated on the host through ``host_callback.call`` on a
    copy of the tape with ``params`` set, and is declared to have shape
    ``(1, len(params))`` and dtype ``JAXInterface.dtype``.

    Returns:
        A one-element tuple holding the list of per-parameter entries.
    """

    def jacobian(params):
        # Work on a copy so the original tape's parameters are left alone.
        tape = self.copy()
        tape.set_parameters(params)
        return tape.jacobian(device, params=params, **tape.jacobian_options)

    flat_g = g.reshape((-1,))
    jac_shape = jax.ShapeDtypeStruct((1, len(params)), JAXInterface.dtype)
    jac = host_callback.call(jacobian, params, result_shape=jac_shape)

    val = flat_g * jac
    flat_val = val.reshape((-1,))
    return (list(flat_val),)  # Comma is on purpose.
def odefun_host_callback(state, driver, *args, **kwargs):
    """
    Calls odefun through a host callback in order to make the rest of the
    ODE solver jit-able.

    The declared result has the same tree structure, shapes and dtypes as
    ``state.parameters``.
    """

    def as_shape_struct(x):
        return jax.ShapeDtypeStruct(x.shape, x.dtype)

    def call_odefun(args_and_kw):
        # Unpack the (args, kwargs) pair that was packed below.
        return odefun(state, driver, *args_and_kw[0], **args_and_kw[1])

    result_shape = jax.tree_map(as_shape_struct, state.parameters)

    # host_callback passes a single argument, so args and kwargs travel
    # together as one tuple.
    packed = (args, kwargs)
    return hcb.call(call_odefun, packed, result_shape=result_shape)
def random_state(hilb: Fock, key, batches: int, *, dtype=np.float32):
    """Draw ``batches`` random configurations of the Fock space ``hilb``.

    Without a particle-number constraint (``hilb.n_particles is None``) every
    site is sampled independently as an integer in ``[0, hilb.n_max]``.
    Otherwise the sampling is delegated to ``_random_states_with_constraint``
    on the host through a host callback.

    Args:
        hilb: The Fock Hilbert space.
        key: JAX PRNG key.
        batches: Number of configurations to generate.
        dtype: dtype of the returned array.

    Returns:
        An array of shape ``(batches, hilb.size)``.
    """
    shape = (batches, hilb.size)

    # Unconstrained space: fast independent sampling.
    if hilb.n_particles is None:
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=hilb.n_max + 1)
        return jnp.asarray(rs, dtype=dtype)

    from jax.experimental import host_callback as hcb

    def constrained_sampler(rng):
        return _random_states_with_constraint(hilb, rng, batches, dtype)

    state = hcb.call(
        constrained_sampler,
        key,
        result_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    return state
def wrapped_exec(params):
    """Forward pass: execute the tapes on the host via a host callback.

    The result is declared as ``total_size`` arrays of shape ``(1,)`` and
    dtype ``dtype``.
    """

    def wrapper(p):
        """Compute the forward pass."""
        new_tapes = []
        for t, a in zip(tapes, p):
            # Execute copies so the original tapes keep their parameters.
            tape_copy = t.copy(copy_operations=True)
            new_tapes.append(tape_copy)
            tape_copy.set_parameters(a)

        with qml.tape.Unwrap(*new_tapes):
            res, _ = execute_fn(new_tapes, **gradient_kwargs)

        return res

    shapes = [jax.ShapeDtypeStruct((1,), dtype) for _ in range(total_size)]
    return host_callback.call(wrapper, params, result_shape=shapes)
def _sample_chain(
    sampler,
    machine: nn.Module,
    parameters: PyTree,
    state: SamplerState,
    chain_length: int,
) -> Tuple[jnp.ndarray, SamplerState]:
    """Sample ``chain_length`` configurations for every chain in one go.

    sample_chain is reimplemented because the whole 'chain' can be drawn at
    once (it is not really a chain anyway). This is much faster because
    python is called into only once.

    Returns:
        A pair ``(samples, new_state)``. ``samples`` has shape
        ``(chain_length, sampler.n_chains_per_rank, sampler.hilbert.size)``
        and dtype ``sampler.dtype``; ``new_state`` is ``state`` with a fresh
        rng key.
    """
    new_rng, rng = jax.random.split(state.rng)

    n_samples = chain_length * sampler.n_chains_per_rank
    numbers = jax.random.choice(
        rng,
        sampler.hilbert.n_states,
        shape=(n_samples,),
        replace=True,
        p=state.pdf,
    )

    # A host callback converts the integers labelling states into valid
    # state arrays, because that code is written with numba and has not yet
    # been converted to jax.
    #
    # For future investigators:
    # this will lead to a crash if numbers_to_states throws,
    # and it throws if we feed it nans!
    def to_states(numbers):
        return sampler.hilbert.numbers_to_states(numbers)

    host_shape = jax.ShapeDtypeStruct(
        (n_samples, sampler.hilbert.size),
        jnp.float64,
    )
    samples = hcb.call(to_states, numbers, result_shape=host_shape)

    samples = jnp.asarray(samples, dtype=sampler.dtype)
    samples = samples.reshape(
        chain_length, sampler.n_chains_per_rank, sampler.hilbert.size
    )

    return samples, state.replace(rng=new_rng)
def _sample_chain(
    sampler: ExactSampler,
    machine: nn.Module,
    parameters: PyTree,
    state: SamplerState,
    chain_length: int,
) -> Tuple[jnp.ndarray, SamplerState]:
    """
    Internal method used for jitting calls.

    Draws ``chain_length * sampler.n_chains_per_rank`` state labels with
    probabilities ``state.pdf``, converts them to configurations on the host
    and returns them together with the state carrying a fresh rng key.
    """
    new_rng, rng = jax.random.split(state.rng)

    total_draws = chain_length * sampler.n_chains_per_rank
    hilbert_size = sampler.hilbert.size

    numbers = jax.random.choice(
        rng,
        sampler.hilbert.n_states,
        shape=(total_draws,),
        replace=True,
        p=state.pdf,
    )

    # A host callback converts the integers labelling states into valid
    # state arrays, because that code is written with numba and has not yet
    # been converted to jax.
    #
    # For future investigators:
    # this will lead to a crash if numbers_to_states throws,
    # and it throws if we feed it nans!
    def labels_to_states(numbers):
        return sampler.hilbert.numbers_to_states(numbers)

    flat_samples = hcb.call(
        labels_to_states,
        numbers,
        result_shape=jax.ShapeDtypeStruct((total_draws, hilbert_size), jnp.float64),
    )

    samples = jnp.asarray(flat_samples, dtype=sampler.dtype).reshape(
        chain_length, sampler.n_chains_per_rank, sampler.hilbert.size
    )
    return samples, state.replace(rng=new_rng)
def _sample_next(sampler, machine, parameters, state):
    """Draw one new configuration per chain.

    State labels are sampled with probabilities ``state.pdf`` and converted
    to configurations on the host.

    Returns:
        A pair ``(new_state, sample)`` where ``new_state`` is ``state`` with
        a fresh rng key and ``sample`` has dtype ``sampler.dtype``.
    """
    new_rng, rng = jax.random.split(state.rng)

    n_chains = sampler.n_chains_per_rank
    numbers = jax.random.choice(
        rng,
        sampler.hilbert.n_states,
        shape=(n_chains,),
        replace=True,
        p=state.pdf,
    )

    # A host callback converts the integers labelling states into valid
    # state arrays, because that code is written with numba and has not yet
    # been converted to jax.
    def to_states(numbers):
        return sampler.hilbert.numbers_to_states(numbers)

    sample = hcb.call(
        to_states,
        numbers,
        result_shape=jax.ShapeDtypeStruct(
            (n_chains, sampler.hilbert.size), jnp.float64
        ),
    )

    new_state = state.replace(rng=new_rng)
    return new_state, jnp.asarray(sample, dtype=sampler.dtype)
def error_norm(self, error_norm: Union[str, Callable]):
    """Set the norm used to measure the integration error.

    Args:
        error_norm: Either a callable, or one of the strings
            ``"euclidean"``, ``"maximum"`` or ``"qgt"``.

    Raises:
        ValueError: If ``error_norm`` is neither a callable nor one of the
            recognised names.
    """
    if isinstance(error_norm, Callable):
        norm_fn = error_norm
    elif error_norm == "euclidean":
        norm_fn = euclidean_norm
    elif error_norm == "maximum":
        norm_fn = maximum_norm
    elif error_norm == "qgt":
        w = self.state.parameters
        norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
        # The QGT norm goes through a host callback since it accesses the
        # driver.
        # TODO: make this also an hashablepartial on self to reduce recompilation
        norm_fn = lambda x: hcb.call(
            HashablePartial(qgt_norm, self),
            x,
            result_shape=jax.ShapeDtypeStruct((), norm_dtype),
        )
    else:
        raise ValueError(
            "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum',"
            f" but {error_norm} was passed."
        )

    self._error_norm = norm_fn

    # Keep an already-constructed integrator in sync with the new norm.
    if self.integrator is not None:
        self.integrator.norm = self._error_norm
def wrapped_exec(params):
    """Forward pass: run ``self.execute_device`` on the host.

    The result is declared to have shape ``(1,)`` and dtype
    ``JAXInterface.dtype``.
    """
    out_shape = jax.ShapeDtypeStruct((1,), JAXInterface.dtype)
    # Bind the device so the callback only receives the parameters.
    exec_fn = partial(self.execute_device, device=device)
    return host_callback.call(exec_fn, params, result_shape=out_shape)
def __init__(
    self,
    operator: AbstractOperator,
    variational_state: VariationalState,
    integrator: RKIntegratorConfig,
    *,
    t0: float = 0.0,
    propagation_type="real",
    qgt: LinearOperator = None,
    linear_solver=None,
    linear_solver_restart: bool = False,
    error_norm: Union[str, Callable] = "euclidean",
):
    r"""
    Initializes the time evolution driver.

    Args:
        operator: The generator of the dynamics (Hamiltonian for pure states,
            Lindbladian for density operators).
        variational_state: The variational state.
        integrator: Configuration of the algorithm used for solving the ODE.
        t0: Initial time at the start of the time evolution.
        propagation_type: Determines the equation of motion: "real" for the
            real-time Schrödinger equation (SE), "imag" for the imaginary-time SE.
        qgt: The QGT specification.
        linear_solver: The solver for solving the linear system determining the
            time evolution.
        linear_solver_restart: If False (default), the last solution of the
            linear system is used as initial value in subsequent steps.
        error_norm: Norm function used to calculate the error with adaptive
            integrators. Can be either "euclidean" for the standard L2 vector
            norm :math:`w^\dagger w`, "maximum" for the maximum norm
            :math:`\max_i |w_i|` or "qgt", in which case the scalar product
            induced by the QGT :math:`S` is used to compute the norm
            :math:`\Vert w \Vert^2_S = w^\dagger S w` as suggested in
            PRL 125, 100503 (2020).
            Additionally, it is possible to pass a custom function with
            signature :code:`norm(x: PyTree) -> float` which maps a PyTree of
            parameters :code:`x` to the corresponding norm.
            Note that norm is used in jax.jit-compiled code.

    Raises:
        ValueError: If ``propagation_type`` is not supported for the given
            state, or if ``error_norm`` is neither a callable nor one of the
            recognised names.
    """
    self._t0 = t0

    # Fill in the defaults: SVD solver, and an automatic QGT using it.
    if linear_solver is None:
        linear_solver = nk.optimizer.solver.svd
    if qgt is None:
        qgt = QGTAuto(solver=linear_solver)

    super().__init__(
        variational_state, optimizer=None, minimized_quantity_name="Generator"
    )

    self._generator_repr = repr(operator)
    if isinstance(operator, AbstractOperator):
        # A fixed operator is collected once and wrapped in a function that
        # ignores its argument.
        op = operator.collect()
        self._generator = lambda _: op
    else:
        self._generator = operator

    self.propagation_type = propagation_type
    if isinstance(variational_state, VariationalMixedState):
        # assuming Lindblad Dynamics
        # TODO: support density-matrix imaginary time evolution
        if propagation_type != "real":
            raise ValueError(
                "only real-time Lindblad evolution is supported for "
                "mixed states"
            )
        self._loss_grad_factor = 1.0
    elif propagation_type == "real":
        self._loss_grad_factor = -1.0j
    elif propagation_type == "imag":
        self._loss_grad_factor = -1.0
    else:
        raise ValueError("propagation_type must be one of 'real', 'imag'")

    self.qgt = qgt
    self.linear_solver = linear_solver
    self.linear_solver_restart = linear_solver_restart

    self._dw = None  # type: PyTree
    self._last_qgt = None

    # Resolve the error norm specification into a callable.
    if isinstance(error_norm, Callable):
        norm_fn = error_norm
    elif error_norm == "euclidean":
        norm_fn = euclidean_norm
    elif error_norm == "maximum":
        norm_fn = maximum_norm
    elif error_norm == "qgt":
        w = self.state.parameters
        norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
        # The QGT norm goes through a host callback since it accesses the
        # driver.
        norm_fn = lambda x: hcb.call(
            HashablePartial(qgt_norm, self),
            x,
            result_shape=jax.ShapeDtypeStruct((), norm_dtype),
        )
    else:
        raise ValueError(
            "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum'."
        )

    self._odefun = HashablePartial(odefun_host_callback, self.state, self)
    self._integrator = integrator(
        self._odefun,
        t0,
        self.state.parameters,
        norm=norm_fn,
    )

    self._stop_count = 0
def wrapped_exec_bwd(params, g):
    """Backward pass: compute the vector-Jacobian products for all tapes.

    When ``gradient_fn`` is a gradient transform, the VJPs are computed on
    the host through ``host_callback.call`` and then split per tape.
    Otherwise ``gradient_fn`` is treated as a device method that returns the
    Jacobians directly.

    Returns:
        A one-element tuple holding a tuple with one entry per tape.
    """
    if not isinstance(gradient_fn, qml.gradients.gradient_transform):
        # Gradient function is a device method.
        with qml.tape.Unwrap(*tapes):
            jacs = gradient_fn(tapes, **gradient_kwargs)

        vjps = [qml.gradients.compute_vjp(d, jac) for d, jac in zip(g, jacs)]
        res = [[jnp.array(p) for p in v] for v in vjps]
        return (tuple(res),)

    def non_diff_wrapper(args):
        """Compute the VJP in a non-differentiable manner."""
        # The last entry is the cotangent; everything before it holds the
        # per-tape parameters.
        p = args[:-1]
        dy = args[-1]

        new_tapes = []
        for t, a in zip(tapes, p):
            tape_copy = t.copy(copy_operations=True)
            new_tapes.append(tape_copy)
            tape_copy.set_parameters(a)
            tape_copy.trainable_params = t.trainable_params

        vjp_tapes, processing_fn = qml.gradients.batch_vjp(
            new_tapes,
            dy,
            gradient_fn,
            reduction="append",
            gradient_kwargs=gradient_kwargs,
        )

        partial_res = execute_fn(vjp_tapes)[0]
        res = processing_fn(partial_res)
        return np.concatenate(res)

    args = tuple(params) + (g,)
    vjps = host_callback.call(
        non_diff_wrapper,
        args,
        result_shape=jax.ShapeDtypeStruct((total_params,), dtype),
    )

    # Group the vjps based on the parameters of the tapes
    res = []
    start = 0
    for p in params:
        stop = start + len(p)
        res.append(vjps[start:stop])
        start = stop

    # Unwrap partial results into ndim=0 arrays to allow
    # differentiability with JAX
    # E.g.,
    # [DeviceArray([-0.9553365], dtype=float32), DeviceArray([0., 0.],
    # dtype=float32)]
    # is mapped to
    # [[DeviceArray(-0.9553365, dtype=float32)], [DeviceArray(0.,
    # dtype=float32), DeviceArray(0., dtype=float32)]].
    if any(r.ndim != 0 for r in res):
        res = [
            [jnp.array(entry) for entry in r] if r.ndim != 0 else r
            for r in res
        ]

    return (tuple(res),)