class QGTJacobianDenseT(LinearOperator):
    """
    Semi-lazy representation of an S Matrix behaving like a linear operator.

    The matrix of gradients O is computed on initialisation, while S itself
    is only materialised when :code:`to_dense` is called.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    O: jnp.ndarray = Uninitialized
    """Gradients O_ij = ∂log ψ(σ_i)/∂p_j of the neural network
    for all samples σ_i at given values of the parameters p_j.

    The average <O_j> is subtracted for each parameter, and the matrix is
    divided by sqrt(#samples) to normalise the S matrix.
    If scale is not None, the columns are normalised to unit norm.
    """

    scale: Optional[jnp.ndarray] = None
    """If not None, holds the 2-norm of each column of the gradient matrix,
    i.e. the sqrt of the diagonal elements of the S matrix.
    """

    mode: str = struct.field(pytree_node=False, default=Uninitialized)
    """Differentiation mode:
        - "real": for real-valued R->R and C->R ansatze, splits the complex
          inputs into real and imaginary part.
        - "complex": for complex-valued R->C and C->C ansatze, splits the
          complex inputs and outputs into real and imaginary part.
        - "holomorphic": for any ansatze. Does not split complex values.
        - "auto": autoselect real or complex.
    """

    _in_solve: bool = struct.field(pytree_node=False, default=False)
    """Internal flag signalling that we are inside the _solve method, so
    matmul must not split the other vector into real and complex parts."""

    def __matmul__(self, vec: Union[PyTree, jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
        # Delegate to the module-level implementation.
        product = _matmul(self, vec)
        return product

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None) -> PyTree:
        solution = _solve(self, solve_fun, y, x0=x0)
        return solution

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        dense = _to_dense(self)
        return dense

    def __repr__(self):
        return f"QGTJacobianDense(diag_shift={self.diag_shift}, scale={self.scale}, mode={self.mode})"
class TiedAutoEncoder:
    """Auto-encoder built from two `nn.dense` calls, the decoder applying its
    kernel transposed.

    Weight tying presumably comes from the encoder and decoder sharing
    parameters through `scope`. TODO confirm against the scope naming.
    """

    latents: int = struct.field(pytree_node=False)
    features: int = struct.field(pytree_node=False)

    def __call__(self, scope, x):
        return self.decode(scope, self.encode(scope, x))

    def encode(self, scope, x):
        """Project `x` (last axis of size `features`) onto `latents` units."""
        assert x.shape[-1] == self.features
        dense = self._tied(nn.dense)
        return dense(scope, x, self.latents, bias=False)

    def decode(self, scope, z):
        """Map `z` (last axis of size `latents`) back to `features` units."""
        assert z.shape[-1] == self.latents
        dense_t = self._tied(nn.dense, transpose=True)
        return dense_t(scope, z, self.features, bias=False)

    def _tied(self, fn, transpose=False):
        """Return `fn`, or a lifted version whose 'param' kernel is transposed
        on the way in and on the way out."""
        if transpose:
            def _transpose_kernel(variables):
                # Only touch the 'param' collection; others pass through.
                if 'param' in variables:
                    param_col = variables['param']
                    param_col['kernel'] = param_col['kernel'].T
                return variables

            return lift.transform_module(
                fn, trans_in_fn=_transpose_kernel, trans_out_fn=_transpose_kernel)
        return fn
class GaussianProcess:
    """Gaussian process defined by a mean function and a kernel function,
    evaluated at `index_points`.

    Attributes:
        index_points: Points at which the process is evaluated.
        mean_function: Static callable mapping points to the mean.
        kernel_function: Static callable ``k(x1, x2)`` returning a covariance
            matrix.
        jitter: Value added (through ``utils.diag_shift``) to the covariance
            diagonal before the Cholesky factorisation, presumably a small
            positive constant for numerical stability.
    """

    index_points: jnp.ndarray
    mean_function: Callable = struct.field(pytree_node=False)
    kernel_function: Callable = struct.field(pytree_node=False)
    jitter: float

    def marginal(self):
        """Return the multivariate normal marginal at ``self.index_points``."""
        kxx = self.kernel_function(self.index_points, self.index_points)
        chol_kxx = jnp.linalg.cholesky(utils.diag_shift(kxx, self.jitter))
        mean = self.mean_function(self.index_points)
        return distributions.MultivariateNormalTriL(mean, chol_kxx)

    def posterior_gp(self, y, x_new, observation_noise_variance, jitter=None):
        """Return a new GP conditioned on observations `y` at `index_points`.

        Args:
            y: Observed values at ``self.index_points``.
            x_new: Index points of the returned (posterior) process.
            observation_noise_variance: Passed to the Schur-complement kernel
                provider used as the posterior kernel.
            jitter: Jitter of the returned GP. When None (the default),
                ``self.jitter`` is reused. An explicit ``0.0`` is honoured.

        Returns:
            A ``GaussianProcess`` whose mean function is the conditional mean
            and whose kernel is the Schur-complement kernel.
        """
        cond_kernel_fn, _ = kernels.SchurComplementKernelProvider.init(
            None, self.kernel_function, self.index_points,
            observation_noise_variance)

        # Only the prior mean at the training points is needed. Computing it
        # directly avoids the Cholesky factorisation `marginal()` performs.
        prior_mean = self.mean_function(self.index_points)

        # The solve does not depend on `x`, so do it once rather than on every
        # evaluation of the conditional mean. `divisor_matrix_cholesky` is
        # used as a lower-triangular factor (the `True` flag).
        alpha = jscipy.linalg.cho_solve(
            (cond_kernel_fn.divisor_matrix_cholesky, True), y - prior_mean)

        def cond_mean_fn(x):
            k_x_train = self.kernel_function(x, self.index_points)
            return self.mean_function(x) + k_x_train @ alpha

        # `is None` rather than truthiness, so jitter=0.0 is not discarded.
        if jitter is None:
            jitter = self.jitter
        return GaussianProcess(x_new, cond_mean_fn, cond_kernel_fn, jitter)
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained.

    The three array fields are pytree leaves. `index_start` and `sizes` are
    declared with ``pytree_node=False`` and so are static metadata.
    """

    # Accumulator for the diagonal preconditioner.
    diagonal_statistics: chex.Array
    # Momentum for the diagonal preconditioner.
    diagonal_momentum: chex.Array
    # Momentum for the shampoo preconditioner.
    momentum: chex.Array
    # Index into the global statistics array.
    index_start: np.int32 = struct.field(pytree_node=False)
    # Sizes of the statistics.
    sizes: Any = struct.field(pytree_node=False)
class QGTOnTheFlyT(LinearOperator):
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    apply_fun: Callable[[PyTree, jnp.ndarray], jnp.ndarray] = struct.field(
        pytree_node=False, default=Uninitialized)
    """The forward pass of the Ansatz."""

    params: PyTree = Uninitialized
    """The first input to apply_fun (parameters of the ansatz)."""

    samples: jnp.ndarray = Uninitialized
    """The second input to apply_fun (points where the ansatz is evaluated)."""

    model_state: Optional[PyTree] = None
    """Optional state of the ansatz."""

    centered: bool = struct.field(pytree_node=False, default=True)
    """Uses S=⟨ΔÔᶜΔÔ⟩ if True (default), S=⟨ÔᶜΔÔ⟩ otherwise.
    The two forms are mathematically equivalent, but might lead to different
    results due to numerical precision. The non-centered variant should be
    approximately 33% faster.
    """

    def __post_init__(self):
        super().__post_init__()

        # Collapse all leading axes so `samples` is 2-D, keeping the last axis.
        if jnp.ndim(self.samples) != 2:
            samples_r = self.samples.reshape((-1, self.samples.shape[-1]))
            # object.__setattr__ bypasses the (presumably frozen) dataclass
            # __setattr__.
            object.__setattr__(self, "samples", samples_r)

    def __matmul__(self, y):
        return onthefly_mat_treevec(self, y)

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None,
               **kwargs) -> PyTree:
        """Solve S x = y with `solve_fun`.

        `x0` defaults to None, consistent with the other QGT operators, so
        callers may omit it.
        """
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        return _to_dense(self)
class CustomRuleNumpy(MetropolisRule):
    """Metropolis rule (numpy backend) proposing moves generated by the terms of
    an operator, chosen at random with probabilities given by `weight_list`.

    Attributes:
        operator: An ``AbstractOperator``. Its ``operators`` are validated by
            ``_check_operators``.
        weight_list: Optional non-negative weights, one per term of
            ``operator`` (array-like of shape ``(operator.n_operators,)``).
            Defaults to uniform weights. They are normalised to sum to 1 in
            ``__post_init__``.
    """

    operator: Any = struct.field(pytree_node=False)
    weight_list: Any = struct.field(pytree_node=False, default=None)

    def __post_init__(self):
        """Validate the operator and normalise the move weights.

        Raises:
            TypeError: if `operator` is not an ``AbstractOperator``.
            ValueError: if the weights have the wrong shape, contain negative
                entries, or are all zero.
        """
        if not isinstance(self.operator, AbstractOperator):
            raise TypeError(
                "Argument to CustomRuleNumpy must be a valid operator, "
                f"not {type(self.operator)}."
            )

        _check_operators(self.operator.operators)

        if self.weight_list is not None:
            # Accept any array-like (e.g. a list), not only ndarrays.
            weights = np.asarray(self.weight_list)
            if weights.shape != (self.operator.n_operators,):
                raise ValueError("move_weights have the wrong shape")
            if weights.min() < 0:
                raise ValueError("move_weights must be positive")
            if weights.sum() == 0:
                # Normalising all-zero weights would produce NaNs.
                raise ValueError("move_weights must not all be zero")
        else:
            weights = np.ones(self.operator.n_operators, dtype=np.float32)

        object.__setattr__(self, "weight_list", weights / weights.sum())

    def init_state(rule, sampler, machine, params, key):
        """Allocate per-batch scratch buffers and the cumulative weights."""
        return CustomRuleState(
            sections=np.empty(sampler.n_batches, dtype=np.int32),
            rand_op_n=np.empty(sampler.n_batches, dtype=np.int32),
            weight_cumsum=rule.weight_list.cumsum(),
        )

    def transition(rule, sampler, machine, parameters, state, rng, σ):
        """Propose new configurations, writing into the state's buffers in place."""
        rule_state = state.rule_state

        # Pick one operator term per sample, written into rand_op_n.
        _pick_random_and_init(
            σ.shape[0],
            rule_state.weight_cumsum,
            out=rule_state.rand_op_n,
        )

        σ_conns, mels = rule.operator.get_conn_filtered(
            state.σ, rule_state.sections, rule_state.rand_op_n
        )

        _choose_and_return(
            state.σ1, σ_conns, mels, rule_state.sections, state.log_prob_corr
        )
class TrainState(train_state.TrainState):
    """Train state with an Optax optimizer.

    Extends ``train_state.TrainState`` with two static callables whose
    behaviour depends on whether the task is classification or regression.

    Attributes:
        logits_fn: Applied to the last layer to obtain the logits.
        loss_fn: Function computing the loss.
    """

    # Both callables are static (not traced as pytree leaves).
    logits_fn: Callable = struct.field(pytree_node=False)
    loss_fn: Callable = struct.field(pytree_node=False)
class TrainingState(struct.PyTreeNode):
    """Training state for the distance model.

    Attributes:
        step: Training step counter.
        epoch: Epoch counter.
        best_loss: Best loss seen so far, keyed by name.
        encoder_fn: Static encoder ``f(variables, x)``.
        distance_fn: Static distance function ``f(variables, x)``.
        domain_discriminator_fn: Static domain discriminator ``f(variables, x)``.
        distance_optimizer: Optax state of the distance optimizer.
        domain_optimizer: Optax state of the domain optimizer.
    """

    step: int
    epoch: int
    best_loss: Dict[str, float]

    # Static callables: not pytree leaves.
    encoder_fn: Callable[[Variables, jnp.ndarray], jnp.ndarray] = struct.field(pytree_node=False)
    distance_fn: Callable[[Variables, jnp.ndarray], jnp.ndarray] = struct.field(pytree_node=False)
    domain_discriminator_fn: Callable[[Variables, jnp.ndarray], jnp.ndarray] = struct.field(pytree_node=False)

    distance_optimizer: optax.OptState
    domain_optimizer: optax.OptState
class MultiScaleRBIGBlockInit:
    """Initialiser for a multi-scale RBIG block.

    Inputs are rescaled by a function built from `filter_shape` and
    `image_shape`, pushed through each entry of `init_functions` in order,
    then un-rescaled.

    Attributes:
        init_functions: Layer initialisers. Each must provide
            ``bijector_and_transform`` and ``transform``.
        filter_shape: Static filter shape used by the multi-scale reshape.
        image_shape: Static image shape used by the multi-scale reshape.
    """

    init_functions: List[dataclass]
    filter_shape: Tuple[int, int] = struct.field(pytree_node=False)
    image_shape: Tuple = struct.field(pytree_node=False)

    def __post_init__(self):
        # `ms_reshape` is derived state, not a field. object.__setattr__ makes
        # this work when the dataclass is frozen (as flax struct dataclasses
        # are) as well as when it is not.
        object.__setattr__(
            self, "ms_reshape",
            init_scale_function(self.filter_shape, self.image_shape))

    def forward_and_params(self, inputs: Array) -> Tuple[Array, "MultiScaleBijector"]:
        """Transform `inputs` while fitting each layer.

        Returns:
            The transformed inputs and a ``MultiScaleBijector`` chaining the
            fitted bijectors.
        """
        outputs = self.ms_reshape.forward(inputs)

        bijectors = []
        for init_fn in self.init_functions:
            outputs, fitted = init_fn.bijector_and_transform(outputs)
            bijectors.append(fitted)

        outputs = self.ms_reshape.inverse(outputs)

        chain = MultiScaleBijector(
            bijectors=bijectors,
            filter_shape=self.filter_shape,
            image_shape=self.image_shape,
        )
        return outputs, chain

    def forward(self, inputs: Array) -> Array:
        """Apply each layer's ``transform`` in the rescaled space."""
        outputs = self.ms_reshape.forward(inputs)
        for init_fn in self.init_functions:
            outputs = init_fn.transform(outputs)
        return self.ms_reshape.inverse(outputs)
class EmaState:
    """Exponential moving average (EMA) of model parameters and model state.

    A ``decay`` of 0 disables the EMA. In that case ``params`` and
    ``model_state`` are kept as None.
    """

    decay: float = struct.field(pytree_node=False, default=0.)
    params: flax.core.FrozenDict = None
    model_state: flax.core.FrozenDict = None

    @staticmethod
    def create(decay, params, model_state):
        """Initialize ema state from the given params and model state."""
        if decay == 0.:
            # default state == disabled
            return EmaState()
        # Identity tree_map: rebuilds the tree containers around the leaves.
        ema_params = jax.tree_util.tree_map(lambda x: x, params)
        ema_model_state = jax.tree_util.tree_map(lambda x: x, model_state)
        return EmaState(decay, ema_params, ema_model_state)

    def update(self, new_params, new_model_state):
        """Blend the new values into the EMA: ``ema * decay + (1 - decay) * new``."""
        if self.decay == 0.:
            return self.replace(params=None, model_state=None)

        decay = self.decay

        def _blend(ema, new):
            return ema * decay + (1. - decay) * new

        # jax.tree_util.tree_map accepts several trees. jax.tree_multimap was
        # deprecated and later removed from JAX.
        new_params = jax.tree_util.tree_map(_blend, self.params, new_params)
        new_model_state = jax.tree_util.tree_map(
            _blend, self.model_state, new_model_state)
        return self.replace(params=new_params, model_state=new_model_state)
class ConditionalGaussianizationFlow(BijectorChain):
    """Bijector chain whose latent distribution is produced by `encoder`
    from conditioning `outputs`.

    Attributes:
        bijectors: The bijectors of the chain.
        base_dist: Static base distribution.
        encoder: Object whose ``forward(outputs)`` returns the conditional
            latent distribution.
    """

    bijectors: Iterable[Bijector]
    base_dist: Distribution = struct.field(pytree_node=False)
    encoder: Callable

    def score_samples(self, inputs, outputs):
        """Per-sample log-likelihood of `inputs` given `outputs`.

        The latent log-probability and the log-determinant are each summed
        over axis 1.
        """
        latents, log_det = self.forward_and_log_det(inputs)
        conditional = self.encoder.forward(outputs)
        return conditional.log_prob(latents).sum(axis=1) + log_det.sum(axis=1)

    def score(self, inputs, outputs):
        """Mean negative log-likelihood."""
        per_sample = self.score_samples(inputs, outputs)
        return -jnp.mean(per_sample)

    def sample(self, outputs: Array, seed: int, n_samples: int):
        """Draw `n_samples` latent samples conditioned on `outputs` and map
        them back through the inverse of the chain."""
        conditional = self.encoder.forward(outputs)
        latent_draws = conditional.sample(seed=seed, sample_shape=n_samples)
        return self.inverse(latent_draws)
class SRLazy(SR):
    """
    Base class holding the parameters for the iterative solution of the
    SR system x = ⟨S⟩⁻¹⟨F⟩, where S is a lazy linear operator.

    Tolerances are applied according to the formula
    :code:`norm(residual) <= max(tol*norm(b), atol)`.
    """

    tol: float = 1.0e-5
    """Relative tolerance for convergence."""

    atol: float = 0.0
    """Absolute tolerance for convergence."""

    maxiter: Optional[int] = None
    """Maximum number of iterations. Iteration will stop after maxiter steps
    even if the specified tolerance has not been achieved.
    """

    M: Optional[Union[Callable, Array]] = None
    """Preconditioner for A. The preconditioner should approximate the inverse
    of A. Effective preconditioning dramatically improves the rate of
    convergence, which implies that fewer iterations are needed to reach a
    given error tolerance.
    """

    centered: bool = struct.field(pytree_node=False, default=True)
    """Uses S=⟨ΔÔᶜΔÔ⟩ if True (default), S=⟨ÔᶜΔÔ⟩ otherwise.
    The two forms are mathematically equivalent, but might lead to different
    results due to numerical precision. The non-centered variant should be
    approximately 33% faster.
    """

    def create(self, *args, **kwargs):
        """Build the lazy S matrix; all arguments are forwarded to LazySMatrix."""
        lazy_s = LazySMatrix(*args, **kwargs)
        return lazy_s
class TrainState:
    """Container for misc training state that's not handled by the optimizer.

    This includes:
    - The base RNG key for each step, replicated across devices.
    - Any metrics output by the training step (that are then logged to the
      history object).
    """

    history: TrainStateHistory = struct.field(pytree_node=False)
    rng: Any
    step: Any
    metrics: Any

    def take_step(self, optimizer, grad, metrics, rng):
        """Apply `grad` with the scheduled learning rate and record the new
        rng, step and metrics.

        Returns:
            The updated optimizer and the updated train state.
        """
        opt_state = optimizer.state
        # A list-valued optimizer state keeps its step on the first entry.
        if isinstance(opt_state, list):
            step = opt_state[0].step
        else:
            step = opt_state.step

        learning_rate = self.history.learning_rate_fn(step)
        new_optimizer = optimizer.apply_gradient(grad, learning_rate=learning_rate)
        return new_optimizer, self.replace(rng=rng, step=step, metrics=metrics)

    def write_history(self):
        """Write the stored metrics to the history and clear step/metrics."""
        # Only the first entry of `step` is logged (presumably it is
        # replicated across devices).
        first_step = self.step[0]
        self.history.write(first_step, self.metrics)
        return self.replace(step=None, metrics=None)
class TrainState():
    """Container for misc training state that's not handled by the optimizer.

    This includes:
    - The base RNG key for each step, replicated across devices.
    - Any metrics output by the training step (that are then logged to the
      history object).
    """

    rng: Any
    step: Any
    metrics: Any
    history: TrainStateHistory = struct.field(pytree_node=False)

    def take_step(self, optimizer, grad, metrics, rng):
        """Apply `grad` with the scheduled learning rate.

        Returns:
            The updated optimizer and the train state with new rng, step and
            metrics.
        """
        current_step = optimizer.state.step
        lr = self.history.learning_rate_fn(current_step)
        next_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
        # TODO(marcvanzee): Remove this when b/162398046 is fixed.
        next_state = self.replace(rng=rng, step=current_step, metrics=metrics)  # pytype: disable=attribute-error
        return next_optimizer, next_state

    def write_history(self):
        """Log the first step entry with the metrics, then clear both fields."""
        first_step = self.step[0]
        self.history.write(first_step, self.metrics)
        return self.replace(step=None, metrics=None)  # pytype: disable=attribute-error
class TrainState(struct.PyTreeNode):
    """Simple train state for the common case with a single Optax optimizer.

    Mirrors the Flax ``train_state.TrainState`` layout and adds an
    exponential moving average (EMA) copy of the parameters.
    """

    # New with respect to Flax definition:
    ema_params: ParamDict
    # Same as in Flax definition of train_state:
    step: int
    params: ParamDict
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState

    def apply_gradients(self, *, grads, ema_momentum=0.9999, **kwargs):
        """Updates `step`, `params`, `ema_params`, `opt_state` and `**kwargs`.

        Note that internally this function calls `.tx.update()` followed by a
        call to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            ema_momentum: momentum for EMA updates. The new EMA is
                ``ema_momentum * ema + (1 - ema_momentum) * new_params``.
            **kwargs: Additional dataclass attributes that should be
                `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one,
            `params` and `opt_state` updated by applying `grads`, `ema_params`
            blended towards the new params, and additional attributes replaced
            as specified by `kwargs`.
        """
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        def update_ema(ema, p):
            return ema_momentum * ema + (1. - ema_momentum) * p

        # jax.tree_util.tree_map maps over several trees. jax.tree_multimap
        # was deprecated and later removed from JAX.
        new_ema_params = jax.tree_util.tree_map(update_ema, self.ema_params, new_params)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            ema_params=new_ema_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    @classmethod
    def create(cls, *, params, tx, **kwargs):
        """Creates a new instance with `step=0` and initialized `opt_state`."""
        opt_state = tx.init(params)
        return cls(
            step=0,
            ema_params=params,
            # Deepcopy to avoid donating the same object twice in train_step.
            params=copy.deepcopy(params),
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )
class Model:
    """DEPRECATION WARNING:
    The `flax.nn` module is Deprecated, use `flax.linen` instead.
    Learn more and find an upgrade guide at
    https://github.com/google/flax/blob/master/flax/linen/README.md

    A Model contains the model parameters, state and definition."""

    module: Type[Module] = struct.field(pytree_node=False)
    params: Any = struct.field(pytree_node=True)

    def __call__(self, *args, **kwargs):
        return self.module.call(self.params, *args, **kwargs)

    def truncate_at(self, module_path):
        """Truncates the model by returning the outputs of the given sub-module.

        Args:
            module_path: the full name of the module (e.g. '/module/sub_module').
                A list or dict of module paths can be provided to obtain the
                intermediate outputs of multiple modules.
        Returns:
            A new model with the truncated outputs. If module_path is a pytree of
            paths, the outputs will have the same structure where each path is
            replaced by the corresponding intermediate output.
        """
        truncated_module_cls = TruncatedModule.partial(
            wrapped_module=self.module, truncate_path=module_path)
        return self.replace(module=truncated_module_cls)

    def __getattr__(self, name):
        """Expose sub-Module classes of `module` as callables bound to `params`.

        Raises:
            AttributeError: if `name` is not a Module subclass on `module`.
        """
        if name == 'module':
            # `__getattr__` only runs for missing attributes. `module` is
            # missing only on a partially constructed instance (e.g. during
            # copy/unpickling). Without this guard, `self.module` below would
            # re-enter __getattr__ forever.
            raise AttributeError(name)
        value = getattr(self.module, name)
        if inspect.isclass(value) and issubclass(value, Module):
            def wrapper(*args, **kwargs):
                return value.call(self.params, *args, **kwargs)
            return wrapper
        raise AttributeError(f'No attribute named "{name}".')

    def __hash__(self):
        # Jax will call hash when the model is passed to a function transform.
        # The compiled function should not be shared among model instances because
        # it closes over the specific parameters of this model instance.
        return id(self)
class Kernel:
    """Thin wrapper around a static kernel function ``kernel_fn(x, x2)``."""

    kernel_fn: Callable = struct.field(pytree_node=False)

    def apply(self, x, x2):
        """Evaluate the kernel between `x` and `x2`."""
        fn = self.kernel_fn
        return fn(x, x2)

    def __call__(self, x, x2=None):
        """Evaluate the kernel. When `x2` is omitted, it defaults to `x`."""
        if x2 is None:
            x2 = x
        return self.apply(x, x2)
class SRLazyGMRES(SRLazy):
    """
    Computes x = ⟨S⟩⁻¹⟨F⟩ by using an iterative GMRES method.

    See `Jax docs <https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.gmres.html#jax.scipy.sparse.linalg.gmres>`_
    for more information.
    """

    restart: int = struct.field(pytree_node=False, default=20)
    """Size of the Krylov subspace (“number of iterations”) built between
    restarts. GMRES works by approximating the true solution x as its
    projection into a Krylov space of this dimension. This parameter
    therefore bounds the maximum accuracy achievable from any guess solution.
    Larger values increase both number of iterations and iteration cost, but
    may be necessary for convergence. The algorithm terminates early if
    convergence is achieved before the full subspace is built.
    Default is 20.
    """

    solve_method: str = struct.field(pytree_node=False, default="batched")
    """(‘incremental’ or ‘batched’) The ‘incremental’ solve method builds a
    QR decomposition for the Krylov subspace incrementally during the GMRES
    process using Givens rotations. This improves numerical stability and
    gives a free estimate of the residual norm that allows for early
    termination within a single “restart”. In contrast, the ‘batched’ solve
    method solves the least squares problem from scratch at the end of each
    GMRES iteration. It does not allow for early termination, but has much
    less overhead on GPUs.
    """

    def solve_fun(self):
        """Return `jax.scipy.sparse.linalg.gmres` with this object's options bound."""
        gmres_options = dict(
            tol=self.tol,
            atol=self.atol,
            maxiter=self.maxiter,
            M=self.M,
            restart=self.restart,
            solve_method=self.solve_method,
        )
        return partial(jax.scipy.sparse.linalg.gmres, **gmres_options)
class ConditionalModel:
    """Network whose outputs parameterise a ``LogStddevNormal``.

    Attributes:
        params: Parameters passed to ``model.apply``.
        model: Static object exposing ``apply(params, inputs)``.
    """

    params: dict
    model: Callable = struct.field(pytree_node=False)

    def forward(self, inputs) -> "LogStddevNormal":
        """Return the conditional distribution for `inputs`.

        The last axis of the network output is split in half. The first half
        gives the means and the second half the log standard deviations.
        """
        outputs = self.model.apply(self.params, inputs)

        # Size the split on the same (last) axis that is sliced below, so the
        # split is correct for outputs of any rank, not only 2-D ones.
        split = outputs.shape[-1] // 2

        means = outputs[..., :split]
        log_stds = outputs[..., split:]

        return LogStddevNormal(loc=means, log_scale=log_stds)
class HamiltonianRuleNumpy(MetropolisRule):
    """
    Rule for Numpy sampler backend proposing moves according to the terms in an operator.
    In this case, the transition matrix is taken to be:

    .. math::

       T( \\mathbf{s} \\rightarrow \\mathbf{s}^\\prime) = \\frac{1}{\\mathcal{N}(\\mathbf{s})}\\theta(|H_{\\mathbf{s},\\mathbf{s}^\\prime}|),

    """

    operator: AbstractOperator = struct.field(pytree_node=False)
    """The (hermitian) operator giving the transition amplitudes."""

    def __post_init__(self):
        """Raise TypeError if `operator` is not an ``AbstractOperator``."""
        if not isinstance(self.operator, AbstractOperator):
            # Use self.operator: a bare `operator` name is not in scope here.
            raise TypeError(
                "Argument to HamiltonianRuleNumpy must be a valid operator, "
                f"not {type(self.operator)}."
            )

    def init_state(rule, sampler, machine, params, key):
        """Check that the Hilbert spaces match and allocate the sections buffer.

        Raises:
            ValueError: if the sampler and operator Hilbert spaces differ.
        """
        if sampler.hilbert != rule.operator.hilbert:
            raise ValueError(
                f"The hilbert space of the sampler ({sampler.hilbert}) and the "
                f"hilbert space of the operator ({rule.operator.hilbert}) for "
                "HamiltonianRule must be the same."
            )
        return HamiltonianRuleState(
            sections=np.empty(sampler.n_batches, dtype=np.int32))

    def transition(rule, sampler, machine, parameters, state, rng, σ):
        """Propose moves, writing into `state.σ1` and `state.log_prob_corr` in place."""
        σ = state.σ
        σ1 = state.σ1
        log_prob_corr = state.log_prob_corr

        sections = state.rule_state.sections
        σp = rule.operator.get_conn_flattened(σ, sections)[0]

        rand_vec = rng.uniform(0, 1, size=σ.shape[0])

        _choose(σp, sections, σ1, log_prob_corr, rand_vec)

        # Reuse `sections` to receive the connection counts of σ1, then
        # subtract their log from the correction in place.
        rule.operator.n_conn(σ1, sections)
        log_prob_corr -= np.log(sections)

    def __repr__(self):
        return f"HamiltonianRuleNumpy({self.operator})"
class EmaState:
    """Exponential moving average (EMA) of a variables collection.

    A ``decay`` of 0 disables the EMA. In that case ``variables`` is kept as
    None.
    """

    decay: float = struct.field(pytree_node=False, default=0.)
    variables: flax.core.FrozenDict[str, Any] = None

    @staticmethod
    def create(decay, variables):
        """Initialize ema state from `variables`."""
        if decay == 0.:
            # default state == disabled
            return EmaState()
        # Identity tree_map: rebuilds the tree containers around the leaves.
        ema_variables = jax.tree_util.tree_map(lambda x: x, variables)
        return EmaState(decay, ema_variables)

    def update(self, new_variables):
        """Blend the new variables into the EMA: ``ema * decay + (1 - decay) * new``."""
        if self.decay == 0.:
            return self.replace(variables=None)
        # jax.tree_util.tree_map maps over several trees. jax.tree_multimap
        # was deprecated and later removed from JAX.
        new_ema_variables = jax.tree_util.tree_map(
            lambda ema, p: ema * self.decay + (1. - self.decay) * p,
            self.variables, new_variables)
        return self.replace(variables=new_ema_variables)
class Model:
    """A Model contains the model parameters, state and definition."""

    module: Module = struct.field(pytree_node=False)
    params: Any

    def __call__(self, *args, **kwargs):
        return self.module.call(self.params, *args, **kwargs)

    def truncate_at(self, module_path):
        """Truncate the model by returning the outputs of the given sub-module.

        Args:
            module_path: the full name of the module (eg. '/module/sub_module').
                A list or dict of module paths can be provided to obtain the
                intermediate outputs of multiple modules.
        Returns:
            A new model with the truncated outputs. If module_path is a pytree of
            paths, the outputs will have the same structure where each path is
            replaced by the corresponding intermediate output.
        """
        truncated_module_cls = TruncatedModule.partial(
            wrapped_module=self.module, truncate_path=module_path)
        return self.replace(module=truncated_module_cls)

    def __getattr__(self, name):
        """Expose sub-Module classes of `module` as callables bound to `params`.

        Raises:
            AttributeError: if `name` is not a Module subclass on `module`.
        """
        if name == 'module':
            # Reached only when `module` is not set yet (e.g. a partially
            # constructed copy/unpickled instance). Without this guard,
            # `self.module` below would recurse forever.
            raise AttributeError(name)
        value = getattr(self.module, name)
        if inspect.isclass(value) and issubclass(value, Module):
            def wrapper(*args, **kwargs):
                return value.call(self.params, *args, **kwargs)
            return wrapper
        raise AttributeError(f'No attribute named "{name}".')

    def __hash__(self):
        # Jax will call hash when model is passed to a function transform.
        # the compiled function should not be shared among model instances because
        # it closes over the specific parameters of this model instance.
        return id(self)
class QGTOnTheFlyT(LinearOperator):
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    _mat_vec: Callable[[PyTree, float], PyTree] = Uninitialized
    """The S matrix-vector product as generated by mat_vec_factory.
    It's a jax.Partial, so can be used as pytree_node."""

    _params: PyTree = Uninitialized
    """The first input to apply_fun (parameters of the ansatz).
    Only used as a shape placeholder."""

    _chunking: bool = struct.field(pytree_node=False, default=False)
    """Whether the implementation with chunks is used, which currently does
    not support vmapping over it."""

    def __matmul__(self, y):
        return onthefly_mat_treevec(self, y)

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None,
               **kwargs) -> PyTree:
        """Solve S x = y with `solve_fun`.

        `x0` defaults to None, consistent with the other QGT operators, so
        callers may omit it.
        """
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        return _to_dense(self)

    def __repr__(self):
        return f"QGTOnTheFly(diag_shift={self.diag_shift})"
class GraphParameters:
    """Holds a graph's characteristics.

    Every field is declared with ``pytree_node=False``, i.e. it is static
    metadata rather than a traced leaf.

    Attributes:
        node_vocab_size: Number of different node types.
        num_relation_types: Number of different edge types.
        node_feature_dim: Dimension of the fixed representation of each node.
        node_feature_kind: Whether the node features are categorical or real.
        task_vocab_size: Number of different types of tasks.
        task_feature_dim: Dimension of the fixed representation of each task.
        task_feature_kind: Whether the task features are categorical or real.
    """

    # Node description.
    node_vocab_size: int = struct.field(pytree_node=False)
    num_relation_types: int = struct.field(pytree_node=False)
    node_feature_dim: int = struct.field(pytree_node=False)
    node_feature_kind: NodeFeatureKind = struct.field(pytree_node=False)

    # Task description.
    task_vocab_size: int = struct.field(pytree_node=False)
    task_feature_dim: int = struct.field(pytree_node=False)
    task_feature_kind: NodeFeatureKind = struct.field(pytree_node=False)
class LazySMatrix:
    """
    Lazy representation of an S Matrix behaving like a linear operator.

    The S matrix is not computed up front; :code:`to_dense` materialises it.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    apply_fun: Callable[[PyTree, jnp.ndarray], jnp.ndarray] = struct.field(pytree_node=False)
    """The forward pass of the Ansatz."""

    params: PyTree
    """The first input to apply_fun (parameters of the ansatz)."""

    samples: jnp.ndarray
    """The second input to apply_fun (points where the ansatz is evaluated)."""

    sr: SRLazy
    """Parameters for the solution of the system."""

    model_state: Optional[PyTree] = None
    """Optional state of the ansatz."""

    x0: Optional[PyTree] = None
    """Optional initial guess for the iterative solution."""

    def __matmul__(self, vec):
        return lazysmatrix_mat_treevec(self, vec)

    def __rtruediv__(self, y):
        # `y / S` is shorthand for `S.solve(y)` and returns the same pair.
        return self.solve(y)

    def solve(self, y: PyTree, x0: Optional[PyTree] = None) -> PyTree:
        """
        Solve the linear system x=⟨S⟩⁻¹⟨y⟩ with the chosen iterative solver.

        Args:
            y: the vector y in the system above.
            x0: optional initial guess for the solution. Falls back to
                ``self.x0`` when omitted.

        Returns:
            x: the PyTree solving the system.
            info: optional additional information provided by the solver. Might be
                None if no additional information is provided.
        """
        initial_guess = self.x0 if x0 is None else x0
        solution, info = apply_onthefly(self, y, initial_guess)
        return solution, info

    @jax.jit
    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        n_params = nkjax.tree_size(self.params)
        identity = jax.numpy.eye(n_params)
        # Apply S to every row of the identity. The first vmapped argument is
        # broadcast (in_axes=None) and ignored by the lambda.
        apply_to_basis = jax.vmap(lambda _op, basis_vec: self @ basis_vec, in_axes=(None, 0))
        return apply_to_basis(self, identity)
class Sampler(abc.ABC):
    """
    Abstract base class for all samplers.

    It contains the fields that all of them should possess, defining the
    common API. Note that fields marked with pytree_node=False are treated as
    static arguments when jitting.
    """

    hilbert: AbstractHilbert = struct.field(pytree_node=False)
    """Hilbert space to be sampled."""

    n_chains: int = struct.field(pytree_node=False, default=16)
    """Number of batches along the chain"""

    machine_pow: int = struct.field(default=2)
    """Exponent of the pdf sampled"""

    dtype: type = struct.field(pytree_node=False, default=np.float64)
    """DType of the states returned."""

    def __post_init__(self):
        """Validate the fields.

        Raises:
            ValueError: if `hilbert` is not an ``AbstractHilbert`` or if
                `n_chains` is not a positive int.
        """
        if not isinstance(self.hilbert, AbstractHilbert):
            raise ValueError(
                "hilbert must be a subtype of netket.hilbert.AbstractHilbert, "
                + "instead, type {} is not.".format(type(self.hilbert)))

        # Either condition alone is a reason to reject, hence `or`.
        if not isinstance(self.n_chains, int) or self.n_chains <= 0:
            raise ValueError("n_chains must be a positive integer")

    @property
    def n_batches(self) -> int:
        """
        The batch size of the configuration σ used by this sampler.

        In general, it is equivalent to :attr:`~Sampler.n_chains`.
        """
        return self.n_chains

    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Returns a closure with the log_pdf function encoded by this sampler.

        Note: the result is returned as an HashablePartial so that the closure
        does not trigger recompilation.

        Args:
            model: The machine, or apply_fun

        Returns:
            the log probability density function
        """
        apply_fun = get_afun_if_module(model)
        # NOTE(review): a new lambda is created on every call. Depending on how
        # HashablePartial hashes its function, this may defeat the caching
        # mentioned above. TODO confirm.
        log_pdf = HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).real,
            apply_fun,
        )
        return log_pdf

    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedType] = None,
    ) -> SamplerState:
        """
        Creates the structure holding the state of the sampler.

        If you want reproducible samples, you should specify `seed`, otherwise
        the state will be initialised randomly.

        If running across several MPI processes, all sampler_states are
        guaranteed to be in a different (but deterministic) state. This is
        achieved by first reducing (summing) the seed provided to every MPI
        rank, then generating n_rank seeds starting from the reduced one, and
        every rank is initialized with one of those seeds.

        The resulting state is guaranteed to be a frozen python dataclass (in
        particular, a flax's dataclass), and it can be serialized using Flax
        serialization methods.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. If not specified, a random seed will be used.

        Returns:
            The structure holding the state of the sampler. In general you should not expect
            it to be in a valid state, and should reset it before use.
        """
        key = nkjax.PRNGKey(seed)
        return sampler._init_state(get_afun_if_module(machine), parameters,
                                   nkjax.mpi_split(key))

    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Resets the state of the sampler. To be used every time the parameters are changed.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided,
                it will be constructed by calling
                :code:`sampler.init_state(machine, parameters)` with a random seed.

        Returns:
            A valid sampler state.
        """
        if state is None:
            # As documented: build a fresh state with a random seed.
            state = sampler.init_state(machine, parameters)
        return sampler._reset(get_afun_if_module(machine), parameters, state)

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the markov chain.

        Args:
            machine: a Flax module or callable apply function with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided,
                it will be constructed by calling
                :code:`sampler.reset(machine, parameters)` with a random seed.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        if state is None:
            # init_state alone is not guaranteed to be valid; reset it first,
            # as documented.
            state = sampler.reset(machine, parameters)
        return sampler._sample_next(get_afun_if_module(machine), parameters,
                                    state)

    def sample(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Samples chain_length elements along the chains.

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: The model or callable to sample from (if it's a function it should have
                the signature :code:`f(parameters, σ) -> jnp.ndarray`).
            parameters: The PyTree of parameters of the model.
            state: current state of the sampler. If None, then initialises it.
            chain_length: (default=1), the length of the chains.

        Returns:
            σ: The next batch of samples.
            state: The new state of the sampler
        """
        # Delegates to the module-level `sample` function.
        return sample(sampler,
                      machine,
                      parameters,
                      state=state,
                      chain_length=chain_length)

    @partial(jax.jit, static_argnums=(1, 4))
    def _sample_chain(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: SamplerState,
        chain_length: int,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        # sample_next returns (state, σ), i.e. the (carry, output) pair scan expects.
        _sample_next = lambda state, _: sampler.sample_next(
            machine, parameters, state)

        state, samples = jax.lax.scan(
            _sample_next,
            state,
            xs=None,
            length=chain_length,
        )

        return samples, state

    @abc.abstractmethod
    def _init_state(sampler, machine, params, seed) -> SamplerState:
        """
        Implementation of init_state for subclasses of Sampler.
        If you sub-class Sampler, you should define this and not init_state
        itself, because init_state contains some common logic.
        """
        raise NotImplementedError("init_state Not Implemented")

    @abc.abstractmethod
    def _reset(sampler, machine, parameters, state):
        """
        Implementation of reset for subclasses of Sampler.
        If you sub-class Sampler, you should define _reset and not reset
        itself, because reset contains some common logic.
        """
        raise NotImplementedError("reset Not Implemented")

    @abc.abstractmethod
    def _sample_next(sampler, machine, parameters, state=None):
        """
        Implementation of sample_next for subclasses of Sampler.
        If you sub-class Sampler, you should define _sample_next and not
        sample_next itself, because sample_next contains some common logic.
        """
        raise NotImplementedError("sample_next Not Implemented")
class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples an Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be specified.

    Every one of the ``n_chains`` chains is simulated as ``n_replicas`` replicas,
    each at its own inverse temperature ``beta``, so ``n_chains * n_replicas``
    configurations are propagated in total (see :attr:`n_batches`).
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """The number of replicas"""

    def __init__(
        self,
        hilbert: AbstractHilbert,
        rule: MetropolisRule,
        *,
        n_replicas: int = 32,
        **kwargs,
    ):
        """
        ``MetropolisPtSampler`` is a Metropolis-Hastings sampler with parallel
        tempering, using a transition rule to perform moves in the Markov Chain.
        The transition kernel is used to generate a proposed state
        :math:`s^\\prime`, starting from the current state :math:`s`.
        The move is accepted with probability

        .. math::
            A(s\\rightarrow s^\\prime) = \\mathrm{min}\\left (1,\\frac{P(s^\\prime)}{P(s)} F(e^{L(s,s^\\prime)})\\right),

        where the probability being sampled from is :math:`P(s)=|M(s)|^p`.
        Here :math:`M(s)` is a user-provided function (the machine), :math:`p`
        is also user-provided with default value :math:`p=2`, and
        :math:`L(s,s^\\prime)` is a suitable correcting factor computed by the
        transition kernel. For every replica, the log of the acceptance ratio
        is additionally multiplied by that replica's inverse temperature ``beta``.

        Args:
            hilbert: The hilbert space to sample
            rule: A `MetropolisRule` to generate random transitions from a given state as
                    well as uniform random states.
            n_replicas: The number of replicas (inverse temperatures) simulated for
                every chain. Must be an even integer > 0 (default = 32).
            n_chains: The number of Markov Chain to be run in parallel on a single process.
            n_sweeps: The number of exchanges that compose a single sweep.
                    If None, sweep_size is equal to the number of degrees of freedom being sampled
                    (the size of the input vector s to the machine).
            machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float32).

        All keyword arguments other than ``n_replicas`` are forwarded to
        ``MetropolisSampler.__init__``.

        Raises:
            RuntimeError: if the ``NETKET_EXPERIMENTAL`` flag is not enabled.
        """
        if not config.FLAGS["NETKET_EXPERIMENTAL"]:
            raise RuntimeError("""
                Parallel Tempering samplers are under development and
                are known not to work.

                If you want to debug it, set the environment variable
                NETKET_EXPERIMENTAL=1
                """)

        # Bypass the dataclass __setattr__ (instances are presumably frozen);
        # n_replicas must be set before the parent initialisation runs.
        object.__setattr__(self, "n_replicas", n_replicas)

        super().__init__(hilbert, rule, **kwargs)

    def __post_init__(self):
        """Validates the fields after dataclass initialisation.

        Raises:
            ValueError: if ``n_replicas`` is not a positive, even ``int``.
        """
        super().__post_init__()
        # The exchange step pairs replicas two by two, so the number of
        # replicas must be even.
        if not (
            isinstance(self.n_replicas, int)
            and self.n_replicas > 0
            and np.mod(self.n_replicas, 2) == 0
        ):
            raise ValueError("n_replicas must be an even integer > 0.")

    @property
    def n_batches(self):
        """Total number of configurations propagated: n_chains * n_replicas."""
        return self.n_chains * self.n_replicas

    def _init_state(
        sampler, machine, params: PyTree, key: PRNGKey
    ) -> MetropolisPtSamplerState:
        """Builds the initial sampler state.

        Configurations start at zero, and every chain gets the inverse
        temperatures ``1, 1 - 1/R, ..., 1/R`` with ``R = n_replicas``.
        """
        key_state, key_rule = jax.random.split(key, 2)
        σ = jnp.zeros(
            (sampler.n_batches, sampler.hilbert.size),
            dtype=sampler.dtype,
        )

        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)

        # One row of inverse temperatures per chain: shape (n_chains, n_replicas).
        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            n_samples=0,
            n_accepted=0,
            beta=beta,
            beta_0_index=jnp.zeros((sampler.n_chains,), dtype=int),
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=int
            ),
            beta_position=jnp.zeros((sampler.n_chains,)),
            beta_diffusion=jnp.zeros((sampler.n_chains,)),
            exchange_steps=0,
        )

    def _reset(sampler, machine, parameters: PyTree, state: MetropolisPtSamplerState):
        """Draws new random configurations and resets the statistics.

        The inverse temperatures (``beta``) and ``beta_0_index`` are currently
        NOT reset here and are carried over from ``state``.
        """
        new_rng, rng = jax.random.split(state.rng)

        σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        return state.replace(
            σ=σ,
            rng=new_rng,
            rule_state=rule_state,
            n_samples=0,
            n_accepted=0,
            # int dtype, consistent with _init_state.
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=int
            ),
            beta_position=jnp.zeros((sampler.n_chains,)),
            beta_diffusion=jnp.zeros((sampler.n_chains,)),
            exchange_steps=0,
        )

    def _sample_next(
        sampler, machine, parameters: PyTree, state: MetropolisPtSamplerState
    ):
        """Performs ``n_sweeps`` tempered Metropolis moves + replica exchanges.

        Returns:
            A tuple ``(new_state, σ)``. ``σ`` has one configuration per chain:
            the row of the replica currently at ``beta_0_index``.
        """
        new_rng, rng = jax.random.split(state.rng)

        with loops.Scope() as s:
            s.key = rng
            s.σ = state.σ
            s.log_prob = sampler.machine_pow * machine(parameters, state.σ).real
            s.beta = state.beta

            # for logging
            s.beta_0_index = state.beta_0_index
            s.n_accepted_per_beta = state.n_accepted_per_beta
            s.beta_position = state.beta_position
            s.beta_diffusion = state.beta_diffusion

            for i in s.range(sampler.n_sweeps):
                # One key is carried to the next iteration. key1 drives the
                # transition kernel, key2 the acceptance test, and key3/key4
                # the replica-exchange step.
                s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)

                beta = s.beta

                # --- Metropolis move, tempered by each replica's beta ---
                σp, log_prob_correction = sampler.rule.transition(
                    sampler, machine, parameters, state, key1, s.σ
                )
                proposal_log_prob = (
                    sampler.machine_pow * machine(parameters, σp).real
                )

                log_ratio = proposal_log_prob - s.log_prob
                if log_prob_correction is not None:
                    log_ratio = log_ratio + log_prob_correction

                uniform = jax.random.uniform(key2, shape=(sampler.n_batches,))
                do_accept = uniform < jnp.exp(beta.reshape((-1,)) * log_ratio)

                # do_accept must match ndim of proposal and state (which is 2)
                s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
                n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                    (sampler.n_chains, sampler.n_replicas)
                )

                s.log_prob = jax.numpy.where(
                    do_accept.reshape(-1), proposal_log_prob, s.log_prob
                )

                # --- Exchange betas ---
                # Randomly decide, per chain, whether replicas are paired
                # starting from an even (0) or an odd (1) index.
                swap_order = jax.random.randint(
                    key3,
                    minval=0,
                    maxval=2,
                    shape=(sampler.n_chains,),
                )  # 0 or 1

                # indices of even swapped elements (per-row)
                idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                    (1, -1)
                ) + swap_order.reshape((-1, 1))
                # indices of odd swapped elements (per-row)
                inn = (idxs + 1) % sampler.n_replicas

                # for every rows of the input, swap elements at idxs with elements at inn
                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def swap_rows(beta_row, idxs, inn):
                    proposed_beta = jax.ops.index_update(
                        beta_row,
                        idxs,
                        beta_row[inn],
                        unique_indices=True,
                        indices_are_sorted=True,
                    )
                    proposed_beta = jax.ops.index_update(
                        proposed_beta,
                        inn,
                        beta_row[idxs],
                        unique_indices=True,
                        indices_are_sorted=False,
                    )
                    return proposed_beta

                proposed_beta = swap_rows(beta, idxs, inn)

                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def compute_proposed_prob(prob, idxs, inn):
                    prob_rescaled = prob[idxs] + prob[inn]
                    return prob_rescaled

                # compute the probability of the swaps
                # NOTE(review): this uses `state.beta` rather than the
                # loop-carried `beta`. They coincide only while `s.beta` is
                # never updated (see the TODO below).
                log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                    (sampler.n_chains, sampler.n_replicas)
                )
                prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

                uniform = jax.random.uniform(
                    key4, shape=(sampler.n_chains, sampler.n_replicas // 2)
                )

                do_swap = uniform < prob_rescaled

                # Duplicate each pair decision so that it applies to both
                # members of the pair: shape (n_chains, n_replicas).
                do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                    (-1, sampler.n_replicas)
                )  # concat along last dimension

                # roll if swap_order is odd
                @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
                def fix_swap(do_swap, swap_order):
                    return jax.lax.cond(
                        swap_order == 0,
                        lambda x: x,
                        lambda x: jnp.roll(x, 1),
                        do_swap,
                    )

                do_swap = fix_swap(do_swap, swap_order)

                new_beta = jax.numpy.where(do_swap, proposed_beta, beta)
                # TODO: the exchanged temperatures are not applied yet.
                # `s.beta = new_beta` is disabled while this sampler is under
                # development, so `new_beta` is currently unused.

                swap_order = swap_order.reshape(-1)

                # NOTE(review): `beta_0_moved` and `proposed_beta_0_index` are
                # computed from `state.beta_0_index` (the value at entry),
                # while the fallback below uses the loop-carried
                # `s.beta_0_index`. Confirm which one is intended.
                beta_0_moved = jax.vmap(
                    lambda do_swap, idx: do_swap[idx], in_axes=(0, 0), out_axes=0
                )(do_swap, state.beta_0_index)
                # Move the tracked index by +1 or -1, depending on the parity
                # of the index and on the pairing order.
                proposed_beta_0_index = jnp.mod(
                    state.beta_0_index
                    + (-jnp.mod(swap_order, 2) * 2 + 1)
                    * (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                    sampler.n_replicas,
                )
                s.beta_0_index = jnp.where(
                    beta_0_moved, proposed_beta_0_index, s.beta_0_index
                )

                # swap acceptances
                swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs, inn)
                s.n_accepted_per_beta = jax.numpy.where(
                    do_swap,
                    swapped_n_accepted_per_beta,
                    n_accepted_per_beta,
                )

                # Update statistics to compute the diffusion coefficient of the
                # replicas (Welford running mean / sum of squared deviations).
                # `n_steps` counts the exchange steps performed so far INCLUDING
                # this one. It is therefore >= 1, and the first update is not a
                # division by zero.
                n_steps = state.exchange_steps + i + 1
                delta = s.beta_0_index - s.beta_position
                s.beta_position = s.beta_position + delta / n_steps
                delta2 = s.beta_0_index - s.beta_position
                s.beta_diffusion = s.beta_diffusion + delta * delta2

        # NOTE: `n_accepted` is not updated by this sampler; only
        # `n_accepted_per_beta` is tracked.
        new_state = state.replace(
            rng=new_rng,
            σ=s.σ,
            n_samples=state.n_samples + sampler.n_sweeps * sampler.n_chains,
            beta=s.beta,
            beta_0_index=s.beta_0_index,
            beta_position=s.beta_position,
            beta_diffusion=s.beta_diffusion,
            exchange_steps=state.exchange_steps + sampler.n_sweeps,
            n_accepted_per_beta=s.n_accepted_per_beta,
        )

        # Rows of σ are grouped per chain (n_replicas consecutive rows each).
        # Pick the replica at beta_0_index within every chain.
        offsets = jnp.arange(
            0, sampler.n_chains * sampler.n_replicas, sampler.n_replicas
        )

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]

    def __repr__(sampler):
        return (
            "MetropolisPTSampler("
            + "\n hilbert = {},".format(sampler.hilbert)
            + "\n rule = {},".format(sampler.rule)
            + "\n n_chains = {},".format(sampler.n_chains)
            + "\n n_replicas = {},".format(sampler.n_replicas)
            + "\n machine_power = {},".format(sampler.machine_pow)
            + "\n reset_chain = {},".format(sampler.reset_chain)
            + "\n n_sweeps = {},".format(sampler.n_sweeps)
            + "\n dtype = {},".format(sampler.dtype)
            + ")"
        )

    def __str__(sampler):
        return (
            "MetropolisPTSampler("
            + "rule = {}, ".format(sampler.rule)
            + "n_chains = {}, ".format(sampler.n_chains)
            + "n_replicas = {}, ".format(sampler.n_replicas)
            + "machine_power = {}, ".format(sampler.machine_pow)
            + "reset_chain = {}, ".format(sampler.reset_chain)
            + "n_sweeps = {}, ".format(sampler.n_sweeps)
            + "dtype = {})".format(sampler.dtype)
        )
class Point:
    """A point given by ``x`` and ``y`` coordinates plus a metadata payload."""

    # Coordinates.
    x: float
    y: float
    # Arbitrary metadata, declared as a non-pytree (static) field.
    meta: Any = struct.field(pytree_node=False)
class QuantizedValue:
    """State associated with quantized value.

    The tensor is stored in one of three forms: unchanged (``float32``), cast
    to ``bfloat16``, or linearly quantized to ``int8``/``int16``. In the
    integer case a scale (``bucket_size``, computed from a max over axis 0) is
    stored as well, and :meth:`to_float` uses it to reconstruct an
    approximation of the original values.
    """

    quantized: chex.Array
    diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
    bucket_size: chex.Array
    quantized_dtype: jnp.dtype = struct.field(
        pytree_node=False)  # Dtype for the quantized value.
    extract_diagonal: bool = struct.field(
        pytree_node=False)  # In case its centered.
    shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Quantizes ``fvalue`` and wraps the result in a new instance of ``cls``.

        Args:
          fvalue: array to quantize, or an empty list used as a placeholder.
          quantized_dtype: one of jnp.float32, jnp.bfloat16, jnp.int8, jnp.int16.
          extract_diagonal: for the integer dtypes, keep the diagonal of a 2D
            ``fvalue`` unquantized and quantize only the off-diagonal part.

        Returns:
          A new instance of ``cls``. An empty-list ``fvalue`` yields an
          instance whose array fields are empty lists.
        """
        if isinstance(fvalue, list) and not fvalue:
            # Empty placeholder: nothing to quantize.
            return cls([], [], [], quantized_dtype, extract_diagonal, [])
        # Use `cls` (not the hard-coded class) so subclasses round-trip.
        quantized, diagonal_fvalue, bucket_size = cls.quantize(
            fvalue, quantized_dtype, extract_diagonal)
        return cls(quantized, diagonal_fvalue, bucket_size, quantized_dtype,
                   extract_diagonal, list(quantized.shape))

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns quantized value and the bucket.

        Args:
          fvalue: array to quantize.
          quantized_dtype: one of jnp.float32, jnp.bfloat16, jnp.int8, jnp.int16.
          extract_diagonal: only used for the integer dtypes. Requires a 2D
            ``fvalue``; its diagonal is returned separately and removed before
            quantization.

        Returns:
          A tuple ``(quantized, diagonal, bucket_size)``. For float32/bfloat16,
          ``diagonal`` and ``bucket_size`` are empty lists.

        Raises:
          ValueError: for an unsupported ``quantized_dtype``. Also raised when
            ``extract_diagonal`` is set and ``fvalue`` is not 2D, or when
            ``fvalue`` is 0-dimensional.
        """
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        elif quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        if quantized_dtype == jnp.int8:
            # value -128 is not used.
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            # value -32768 is not used.
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
        # max value is mapped to num_buckets

        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f'Input array {fvalue} must be 2D to work with extract_diagonal.')

        diagonal_fvalue = []
        if extract_diagonal:
            diagonal_fvalue = jnp.diag(fvalue)
            # Remove the diagonal entries.
            fvalue = fvalue - jnp.diag(diagonal_fvalue)

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f'Input array {fvalue} must have a strictly positive number of '
                'dimensions.')

        max_abs = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = max_abs / num_buckets
        # Add a leading axis so the scale broadcasts over axis 0 of fvalue.
        bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
        # To avoid divide by 0.0
        bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
                               jnp.ones_like(bs_expanded))
        ratio = fvalue / bs_nonzero
        # We use rounding to remove bias.
        quantized = jnp.round(ratio)
        return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value.

        Empty-list placeholders are returned unchanged, and float32 values are
        returned as-is. bfloat16 values are cast back to float32. Integer
        values are rescaled by ``bucket_size``, and the extracted diagonal (if
        any) is added back.
        """
        if isinstance(self.quantized, list) and not self.quantized:
            return self.quantized

        if self.quantized_dtype == jnp.float32:
            return self.quantized

        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
        val = self.quantized.astype(float_dtype) * bucket_size
        if self.extract_diagonal:
            val += jnp.diag(self.diagonal)
        return val
class Glue(struct.PyTreeNode):
    """Task description whose fields are all static (non-pytree) metadata.

    NOTE(review): presumably describes a GLUE benchmark task; confirm.
    """

    # A str paired with an optional second str (the second entry may be None).
    sentences: Tuple[str, Union[str, None]] = struct.field(pytree_node=False)
    # Number of labels of the task.
    num_labels: int = struct.field(pytree_node=False)
    # Whether the task is treated as regression.
    is_regression: bool = struct.field(pytree_node=False)