class QGTJacobianDenseT(LinearOperator):
    """
    Semi-lazy representation of an S Matrix behaving like a linear operator.

    The matrix of gradients O is computed on initialisation, but not S,
    which can be computed by calling :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    O: jnp.ndarray = Uninitialized
    """Gradients O_ij = ∂log ψ(σ_i)/∂p_j of the neural network
    for all samples σ_i at given values of the parameters p_j
    Average <O_j> subtracted for each parameter
    Divided through with sqrt(#samples) to normalise S matrix
    If scale is not None, columns normalised to unit norm
    """

    scale: Optional[jnp.ndarray] = None
    """If not None, contains 2-norm of each column of the gradient matrix,
    i.e., the sqrt of the diagonal elements of the S matrix
    """

    mode: str = struct.field(pytree_node=False, default=Uninitialized)
    """Differentiation mode:
        - "real": for real-valued R->R and C->R ansatze, splits the complex inputs
                  into real and imaginary part.
        - "complex": for complex-valued R->C and C->C ansatze, splits the complex
                  inputs and outputs into real and imaginary part
        - "holomorphic": for any ansatze. Does not split complex values.
        - "auto": autoselect real or complex.
    """

    _in_solve: bool = struct.field(pytree_node=False, default=False)
    """Internal flag used to signal that we are inside the _solve method and matmul should
    not take apart into real and complex parts the other vector"""
    def __matmul__(
            self, vec: Union[PyTree,
                             jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
        return _matmul(self, vec)

    def _solve(self,
               solve_fun,
               y: PyTree,
               *,
               x0: Optional[PyTree] = None) -> PyTree:
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        return _to_dense(self)

    def __repr__(self):
        return (f"QGTJacobianDense(diag_shift={self.diag_shift}, "
                f"scale={self.scale}, mode={self.mode})")
Example #2
class TiedAutoEncoder:

  latents: int = struct.field(False)
  features: int = struct.field(False)

  def __call__(self, scope, x):
    z = self.encode(scope, x)
    return self.decode(scope, z)

  def encode(self, scope, x):
    assert x.shape[-1] == self.features
    return self._tied(nn.dense)(scope, x, self.latents, bias=False)

  def decode(self, scope, z):
    assert z.shape[-1] == self.latents
    return self._tied(nn.dense, transpose=True)(scope, z, self.features, bias=False)
  
  def _tied(self, fn, transpose=False):
    if not transpose:
      return fn

    def trans(variables):
      if 'param' not in variables:
        return variables
      params = variables['param']
      params['kernel'] = params['kernel'].T
      return variables

    return lift.transform_module(
        fn, trans_in_fn=trans, trans_out_fn=trans)
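The example above uses the older functional flax.core/lift API. Below is a hedged sketch of the same weight-tying idea written against flax.linen, which is an assumption rather than the original code; the class name TiedAE and the parameter name 'kernel' are illustrative.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TiedAE(nn.Module):
  latents: int
  features: int

  @nn.compact
  def __call__(self, x):
    # a single kernel shared by the encoder and the (transposed) decoder
    kernel = self.param('kernel', nn.initializers.lecun_normal(),
                        (self.features, self.latents))
    z = x @ kernel          # encode
    return z @ kernel.T     # decode with the tied, transposed weights

model = TiedAE(latents=4, features=8)
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
recon = model.apply(params, jnp.ones((2, 8)))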
Example #3
class GaussianProcess:
    index_points: jnp.ndarray
    mean_function: Callable = struct.field(pytree_node=False)
    kernel_function: Callable = struct.field(pytree_node=False)
    jitter: float

    def marginal(self):
        kxx = self.kernel_function(self.index_points, self.index_points)
        chol_kxx = jnp.linalg.cholesky(utils.diag_shift(kxx, self.jitter))
        mean = self.mean_function(self.index_points)
        return distributions.MultivariateNormalTriL(mean, chol_kxx)

    def posterior_gp(self, y, x_new, observation_noise_variance, jitter=None):
        """ Returns a new GP conditional on y. """
        cond_kernel_fn, _ = kernels.SchurComplementKernelProvider.init(
            None, self.kernel_function, self.index_points,
            observation_noise_variance)

        marginal = self.marginal()

        def cond_mean_fn(x):
            k_xnew_x = self.kernel_function(x, self.index_points)
            return (self.mean_function(x) + k_xnew_x @ jscipy.linalg.cho_solve(
                (cond_kernel_fn.divisor_matrix_cholesky, True),
                y - marginal.mean))

        jitter = jitter if jitter else self.jitter
        return GaussianProcess(x_new, cond_mean_fn, cond_kernel_fn, jitter)
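For readers without the helper modules used above (utils, kernels, distributions), here is a hedged, self-contained sketch of the same marginal computation with a hypothetical RBF kernel: shift the diagonal by the jitter, take the Cholesky factor, and evaluate the Gaussian log-likelihood.

import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, cholesky

def rbf(x, y, lengthscale=1.0):
    d2 = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2 / lengthscale**2)

def marginal_log_likelihood(x, y, jitter=1e-6):
    kxx = rbf(x, x) + jitter * jnp.eye(x.shape[0])   # diagonal shift
    chol = cholesky(kxx, lower=True)
    alpha = cho_solve((chol, True), y)
    return (-0.5 * y @ alpha
            - jnp.sum(jnp.log(jnp.diag(chol)))
            - 0.5 * y.shape[0] * jnp.log(2.0 * jnp.pi))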
class LocalShardedParameterStats:
  """State associated to each parameter of the model being trained."""
  diagonal_statistics: chex.Array  # Accumulator for diagonal preconditioner
  diagonal_momentum: chex.Array  # Momentum for the diagonal preconditioner
  momentum: chex.Array  # Momentum for the shampoo preconditioner
  index_start: np.int32 = struct.field(
      pytree_node=False)  # Index into global statistics array
  sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
Example #5
class QGTOnTheFlyT(LinearOperator):
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    apply_fun: Callable[[PyTree, jnp.ndarray],
                        jnp.ndarray] = struct.field(pytree_node=False,
                                                    default=Uninitialized)
    """The forward pass of the Ansatz."""

    params: PyTree = Uninitialized
    """The first input to apply_fun (parameters of the ansatz)."""

    samples: jnp.ndarray = Uninitialized
    """The second input to apply_fun (points where the ansatz is evaluated)."""

    model_state: Optional[PyTree] = None
    """Optional state of the ansataz."""

    centered: bool = struct.field(pytree_node=False, default=True)
    """Uses S=⟨ΔÔᶜΔÔ⟩ if True (default), S=⟨ÔᶜΔÔ⟩ otherwise. The two forms are 
    mathematically equivalent, but might lead to different results due to numerical
    precision. The non-centered variant should be approximately 33% faster.
    """
    def __post_init__(self):
        super().__post_init__()

        if jnp.ndim(self.samples) != 2:
            samples_r = self.samples.reshape((-1, self.samples.shape[-1]))
            object.__setattr__(self, "samples", samples_r)

    def __matmul__(self, y):
        return onthefly_mat_treevec(self, y)

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree],
               **kwargs) -> PyTree:
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        return _to_dense(self)
Example #6
class CustomRuleNumpy(MetropolisRule):
    operator: Any = struct.field(pytree_node=False)
    weight_list: Any = struct.field(pytree_node=False, default=None)

    def __post_init__(self):
        # Raise an error if operator is not a valid operator
        if not isinstance(self.operator, AbstractOperator):
            raise TypeError(
                "Argument to CustomRuleNumpy must be a valid operator, "
                "got {}.".format(type(self.operator))
            )
        _check_operators(self.operator.operators)

        if self.weight_list is not None:
            if self.weight_list.shape != (self.operator.n_operators,):
                raise ValueError("move_weights have the wrong shape")
            if self.weight_list.min() < 0:
                raise ValueError("move_weights must be positive")
        else:
            object.__setattr__(
                self,
                "weight_list",
                np.ones(self.operator.n_operators, dtype=np.float32),
            )

        object.__setattr__(
            self, "weight_list", self.weight_list / self.weight_list.sum()
        )

    def init_state(rule, sampler, machine, params, key):
        return CustomRuleState(
            sections=np.empty(sampler.n_batches, dtype=np.int32),
            rand_op_n=np.empty(sampler.n_batches, dtype=np.int32),
            weight_cumsum=rule.weight_list.cumsum(),
        )

    def transition(rule, sampler, machine, parameters, state, rng, σ):
        rule_state = state.rule_state

        _pick_random_and_init(
            σ.shape[0],
            rule_state.weight_cumsum,
            out=rule_state.rand_op_n,
        )

        σ_conns, mels = rule.operator.get_conn_filtered(
            state.σ, rule_state.sections, rule_state.rand_op_n
        )

        _choose_and_return(
            state.σ1, σ_conns, mels, rule_state.sections, state.log_prob_corr
        )
Example #7
class TrainState(train_state.TrainState):
    """Train state with an Optax optimizer.

    The two functions below differ depending on whether the task is classification
    or regression.

    Args:
      logits_fn: Applied to last layer to obtain the logits.
      loss_fn: Function to compute the loss.
    """

    logits_fn: Callable = struct.field(pytree_node=False)
    loss_fn: Callable = struct.field(pytree_node=False)
class TrainingState(struct.PyTreeNode):
    """Training state for distance model."""
    step: int
    epoch: int
    best_loss: Dict[str, float]
    encoder_fn: Callable[[Variables, jnp.ndarray],
                         jnp.ndarray] = (struct.field(pytree_node=False))
    distance_fn: Callable[[Variables, jnp.ndarray],
                          jnp.ndarray] = (struct.field(pytree_node=False))
    domain_discriminator_fn: Callable[[Variables, jnp.ndarray],
                                      jnp.ndarray] = (struct.field(
                                          pytree_node=False))
    distance_optimizer: optax.OptState
    domain_optimizer: optax.OptState
Example #9
class MultiScaleRBIGBlockInit:
    init_functions: List[dataclass]
    filter_shape: Tuple[int, int] = struct.field(pytree_node=False)
    image_shape: Tuple = struct.field(pytree_node=False)

    def __post_init__(self):
        self.ms_reshape = init_scale_function(self.filter_shape,
                                              self.image_shape)

    def forward_and_params(self, inputs: Array) -> Tuple[Array, Array]:

        # rescale data
        inputs = self.ms_reshape.forward(inputs)

        outputs = inputs
        bijectors = []

        # loop through bijectors
        for ibijector in self.init_functions:

            # transform and params
            outputs, bijector = ibijector.bijector_and_transform(outputs)

            # accumulate params
            bijectors.append(bijector)

        # unrescale data
        outputs = self.ms_reshape.inverse(outputs)

        # create bijector chain
        bijectors = MultiScaleBijector(
            bijectors=bijectors,
            filter_shape=self.filter_shape,
            image_shape=self.image_shape,
        )

        return outputs, bijectors

    def forward(self, inputs: Array) -> Array:

        # rescale data
        inputs = self.ms_reshape.forward(inputs)

        outputs = inputs
        for ibijector in self.init_functions:
            outputs = ibijector.transform(outputs)

        # unrescale data
        outputs = self.ms_reshape.inverse(outputs)
        return outputs
class EmaState:
    decay: float = struct.field(pytree_node=False, default=0.)
    params: flax.core.FrozenDict = None
    model_state: flax.core.FrozenDict = None

    @staticmethod
    def create(decay, params, model_state):
        """Initialize ema state"""
        if decay == 0.:
            # default state == disabled
            return EmaState()
        ema_params = jax.tree_map(lambda x: x, params)
        ema_model_state = jax.tree_map(lambda x: x, model_state)
        return EmaState(decay, ema_params, ema_model_state)

    def update(self, new_params, new_model_state):
        if self.decay == 0.:
            return self.replace(params=None, model_state=None)

        new_params = jax.tree_multimap(
            lambda ema, p: ema * self.decay + (1. - self.decay) * p,
            self.params, new_params)
        new_model_state = jax.tree_multimap(
            lambda ema, s: ema * self.decay + (1. - self.decay) * s,
            self.model_state, new_model_state)
        return self.replace(params=new_params, model_state=new_model_state)
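A tiny usage sketch for the EMA update above, written with jax.tree_util.tree_map (jax.tree_multimap is the older spelling of the same multi-input map); the parameter pytrees are made up for illustration.

import jax
import jax.numpy as jnp

decay = 0.99
params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
ema = jax.tree_util.tree_map(lambda x: x, params)          # initial copy, as in create()
new_params = {"w": 2.0 * jnp.ones(3), "b": jnp.ones(3)}
ema = jax.tree_util.tree_map(
    lambda e, p: e * decay + (1.0 - decay) * p, ema, new_params)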
Example #11
class ConditionalGaussianizationFlow(BijectorChain):
    bijectors: Iterable[Bijector]
    base_dist: Distribution = struct.field(pytree_node=False)
    encoder: Callable

    def score_samples(self, inputs, outputs):

        # forward propagation
        z, log_det = self.forward_and_log_det(inputs)

        # encode params
        y_dist = self.encoder.forward(outputs)

        # calculate latent probability
        latent_prob = y_dist.log_prob(z)
        # calculate log probability
        log_prob = latent_prob.sum(axis=1) + log_det.sum(axis=1)

        return log_prob

    def score(self, inputs, outputs):
        return -jnp.mean(self.score_samples(inputs, outputs))

    def sample(self, outputs: Array, seed: int, n_samples: int):
        # encode params
        y_dist = self.encoder.forward(outputs)

        # generate Gaussian samples
        X_g_samples = y_dist.sample(seed=seed, sample_shape=n_samples)
        # # inverse transformation
        return self.inverse(X_g_samples)
Example #12
class SRLazy(SR):
    """
    Base class holding the parameters for the iterative solution of the
    SR system x = ⟨S⟩⁻¹⟨F⟩, where S is a lazy linear operator

    Tolerances are applied according to the formula
    :code:`norm(residual) <= max(tol*norm(b), atol)`
    """

    tol: float = 1.0e-5
    """Relative tolerance for convergences."""

    atol: float = 0.0
    """Absolutes tolerance for convergences."""

    maxiter: int = None
    """Maximum number of iterations. Iteration will stop after maxiter steps even 
    if the specified tolerance has not been achieved.
    """

    M: Optional[Union[Callable, Array]] = None
    """Preconditioner for A. The preconditioner should approximate the inverse of A. 
    Effective preconditioning dramatically improves the rate of convergence, which implies 
    that fewer iterations are needed to reach a given error tolerance.
    """

    centered: bool = struct.field(pytree_node=False, default=True)
    """Uses S=⟨ΔÔᶜΔÔ⟩ if True (default), S=⟨ÔᶜΔÔ⟩ otherwise. The two forms are 
    mathematically equivalent, but might lead to different results due to numerical
    precision. The non-centered variant should be approximately 33% faster.
    """
    def create(self, *args, **kwargs):
        return LazySMatrix(*args, **kwargs)
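A hedged sketch of how the tol/atol/maxiter/M fields above map onto a JAX iterative solver (conjugate gradient here; the concrete subclasses below build analogous partials for GMRES). The toy system is illustrative.

from functools import partial

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

solve_fun = partial(cg, tol=1.0e-5, atol=0.0, maxiter=None, M=None)

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
# convergence criterion: norm(residual) <= max(tol * norm(b), atol)
x, info = solve_fun(lambda v: A @ v, b)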
Example #13
class TrainState:
    """Container for misc training state that's not handeled by the optimizer.
  This includes:
  - The base RNG key for each step, replicated across devices.
  - Any metrics output by the training step (that are then logged to the history
    object)
  """
    history: TrainStateHistory = struct.field(pytree_node=False)
    rng: Any
    step: Any
    metrics: Any

    def take_step(self, optimizer, grad, metrics, rng):
        if isinstance(optimizer.state, list):
            step = optimizer.state[0].step
        else:
            step = optimizer.state.step
        new_optimizer = optimizer.apply_gradient(
            grad, learning_rate=self.history.learning_rate_fn(step))
        new_train_state = self.replace(rng=rng, step=step, metrics=metrics)
        return new_optimizer, new_train_state

    def write_history(self):
        step = self.step[0]
        self.history.write(step, self.metrics)
        return self.replace(step=None, metrics=None)
class TrainState():
  """Container for misc training state that's not handeled by the optimizer.

  This includes:
  - The base RNG key for each step, replicated across devices.
  - Any metrics output by the training step (that are then logged to the history
    object)
  """
  rng: Any
  step: Any
  metrics: Any
  history: TrainStateHistory = struct.field(pytree_node=False)

  def take_step(self, optimizer, grad, metrics, rng):
    step = optimizer.state.step
    new_optimizer = optimizer.apply_gradient(
        grad, learning_rate=self.history.learning_rate_fn(step))
    # TODO(marcvanzee): Remove this when b/162398046 is fixed.
    new_train_state = self.replace(rng=rng, step=step, metrics=metrics)  # pytype: disable=attribute-error
    return new_optimizer, new_train_state

  def write_history(self):
    step = self.step[0]
    self.history.write(step, self.metrics)
    return self.replace(step=None, metrics=None)  # pytype: disable=attribute-error
Example #15
class TrainState(struct.PyTreeNode):
    """Simple train state for the common case with a single Optax optimizer.
  """
    # New with respect to Flax definition:
    ema_params: ParamDict

    # Same as in Flax definition of train_state:
    step: int
    params: ParamDict
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState

    def apply_gradients(self, *, grads, ema_momentum=0.9999, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

    Note that internally this function calls `.tx.update()` followed by a call
    to `optax.apply_updates()` to update `params` and `opt_state`.

    Args:
      grads: Gradients that have the same pytree structure as `.params`.
      ema_momentum: momentum for EMA updates.
      **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

    Returns:
      An updated instance of `self` with `step` incremented by one, `params`
      and `opt_state` updated by applying `grads`, and additional attributes
      replaced as specified by `kwargs`.
    """
        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)

        def update_ema(ema, p):
            return ema_momentum * ema + (1. - ema_momentum) * p

        new_ema_params = jax.tree_multimap(update_ema, self.ema_params,
                                           new_params)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            ema_params=new_ema_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    @classmethod
    def create(cls, *, params, tx, **kwargs):
        """Creates a new instance with `step=0` and initialized `opt_state`."""
        opt_state = tx.init(params)
        return cls(
            step=0,
            ema_params=params,
            # Deepcopy to avoid donating the same object twice in train_step.
            params=copy.deepcopy(params),
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )
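A hedged usage sketch for the TrainState defined above; it assumes that class (with its module's original imports) plus optax, and the toy parameters and loss are illustrative.

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
state = TrainState.create(params=params, tx=optax.sgd(0.1))

def loss_fn(p):
    return jnp.sum((p["w"] - 1.0) ** 2)

grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads, ema_momentum=0.999)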
Example #16
class Model:
    """DEPRECATION WARNING:
  The `flax.nn` module is Deprecated, use `flax.linen` instead. 
  Learn more and find an upgrade guide at 
  https://github.com/google/flax/blob/master/flax/linen/README.md

  A Model contains the model parameters, state and definition."""

    module: Type[Module] = struct.field(pytree_node=False)
    params: Any = struct.field(pytree_node=True)

    def __call__(self, *args, **kwargs):
        return self.module.call(self.params, *args, **kwargs)

    def truncate_at(self, module_path):
        """Truncates the model by returning the outputs of the given sub-module.

    Args:
      module_path: the full name of the module (e.g. '/module/sub_module').
        A list or dict of module paths can be provided to obtain the
        intermediate outputs of multiple modules.
    Returns:
      A new model with the truncated outputs. If module_path is a pytree of
      paths the outputs will have the same structure where each path is
      replaced by the corresponding intermediate output.
    """
        truncated_module_cls = TruncatedModule.partial(
            wrapped_module=self.module, truncate_path=module_path)
        return self.replace(module=truncated_module_cls)

    def __getattr__(self, name):
        value = getattr(self.module, name)
        if inspect.isclass(value) and issubclass(value, Module):

            def wrapper(*args, **kwargs):
                return value.call(self.params, *args, **kwargs)

            return wrapper
        raise AttributeError(f'No attribute named "{name}".')

    def __hash__(self):
        # Jax will call hash when the model is passed to a function transform.
        # The compiled function should not be shared among model instances because
        # it closes over the specific parameters of this model instance.
        return id(self)
Example #17
class Kernel:
    kernel_fn: Callable = struct.field(pytree_node=False)

    def apply(self, x, x2):
        return self.kernel_fn(x, x2)

    def __call__(self, x, x2=None):
        x2 = x if x2 is None else x2
        return self.apply(x, x2)
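A hedged usage sketch for the Kernel wrapper above; it assumes the class carries the flax.struct.dataclass decorator in its original module, and the RBF kernel function here is illustrative.

import jax.numpy as jnp

def rbf_fn(x, x2):
    d2 = jnp.sum((x[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2)

k = Kernel(kernel_fn=rbf_fn)
gram = k(jnp.ones((5, 2)))    # same as k.apply(x, x): a 5x5 Gram matrix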
Example #18
class SRLazyGMRES(SRLazy):
    """
    Computes x = ⟨S⟩⁻¹⟨F⟩ by using an iterative GMRES method.

    See `Jax docs <https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.gmres.html#jax.scipy.sparse.linalg.gmres>`_
    for more information.
    """

    restart: int = struct.field(pytree_node=False, default=20)
    """Size of the Krylov subspace (“number of iterations”) built between restarts. 
    GMRES works by approximating the true solution x as its projection into a Krylov 
    space of this dimension - this parameter therefore bounds the maximum accuracy 
    achievable from any guess solution. Larger values increase both number of iterations 
    and iteration cost, but may be necessary for convergence. The algorithm terminates early 
    if convergence is achieved before the full subspace is built. 
    Default is 20.
    """

    solve_method: str = struct.field(pytree_node=False, default="batched")
    """(‘incremental’ or ‘batched’) – The ‘incremental’ solve method builds a QR 
    decomposition for the Krylov subspace incrementally during the GMRES process 
    using Givens rotations. This improves numerical stability and gives a free estimate 
    of the residual norm that allows for early termination within a single “restart”. In 
    contrast, the ‘batched’ solve method solves the least squares problem from scratch at 
    the end of each GMRES iteration. It does not allow for early termination, but has much 
    less overhead on GPUs.
    """
    def solve_fun(self):
        return partial(
            jax.scipy.sparse.linalg.gmres,
            tol=self.tol,
            atol=self.atol,
            maxiter=self.maxiter,
            M=self.M,
            restart=self.restart,
            solve_method=self.solve_method,
        )
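A hedged sketch of calling the solver that solve_fun configures, showing the restart and solve_method options documented above on a toy system (illustrative values, not part of the original class).

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[3.0, 1.0], [0.0, 2.0]])
b = jnp.array([1.0, 1.0])
x, info = gmres(lambda v: A @ v, b,
                tol=1.0e-5, atol=0.0, maxiter=None,
                restart=20, solve_method="batched")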
Example #19
class ConditionalModel:
    params: dict
    model: Callable = struct.field(pytree_node=False)

    def forward(self, inputs) -> LogStddevNormal:
        # forward pass for params
        outputs = self.model.apply(self.params, inputs)
        # split params
        split = outputs.shape[1] // 2

        # compute means and log stds
        means = outputs[..., :split]
        log_stds = outputs[..., split:]
        dist = LogStddevNormal(loc=means, log_scale=log_stds)

        return dist
Example #20
class HamiltonianRuleNumpy(MetropolisRule):
    """
    Rule for Numpy sampler backend proposing moves according to the terms in an operator.

    In this case, the transition matrix is taken to be:

    .. math::

       T( \\mathbf{s} \\rightarrow \\mathbf{s}^\\prime) = \\frac{1}{\\mathcal{N}(\\mathbf{s})}\\theta(|H_{\\mathbf{s},\\mathbf{s}^\\prime}|),

    """

    operator: AbstractOperator = struct.field(pytree_node=False)
    """The (hermitian) operator giving the transition amplitudes."""
    def __post_init__(self):
        # Raise an error if operator is not a valid operator
        if not isinstance(self.operator, AbstractOperator):
            raise TypeError(
                "Argument to HamiltonianRuleNumpy must be a valid operator, "
                "got {}.".format(type(self.operator)))

    def init_state(rule, sampler, machine, params, key):
        if sampler.hilbert != rule.operator.hilbert:
            raise ValueError(f"""
            The hilbert space of the sampler ({sampler.hilbert}) and the hilbert space
            of the operator ({rule.operator.hilbert}) for HamiltonianRule must be the same.
            """)

        return HamiltonianRuleState(
            sections=np.empty(sampler.n_batches, dtype=np.int32))

    def transition(rule, sampler, machine, parameters, state, rng, σ):
        σ = state.σ
        σ1 = state.σ1
        log_prob_corr = state.log_prob_corr

        sections = state.rule_state.sections
        σp = rule.operator.get_conn_flattened(σ, sections)[0]

        rand_vec = rng.uniform(0, 1, size=σ.shape[0])

        _choose(σp, sections, σ1, log_prob_corr, rand_vec)
        rule.operator.n_conn(σ1, sections)
        log_prob_corr -= np.log(sections)

    def __repr__(self):
        return f"HamiltonianRuleNumpy({self.operator})"
Example #21
class EmaState:
    decay: float = struct.field(pytree_node=False, default=0.)
    variables: flax.core.FrozenDict[str, Any] = None

    @staticmethod
    def create(decay, variables):
        """Initialize ema state"""
        if decay == 0.:
            # default state == disabled
            return EmaState()
        ema_variables = jax.tree_map(lambda x: x, variables)
        return EmaState(decay, ema_variables)

    def update(self, new_variables):
        if self.decay == 0.:
            return self.replace(variables=None)
        new_ema_variables = jax.tree_multimap(
            lambda ema, p: ema * self.decay + (1. - self.decay) * p,
            self.variables, new_variables)
        return self.replace(variables=new_ema_variables)
Example #22
class Model:
    """A Model contains the model paramaters, state and definition."""

    module: Module = struct.field(pytree_node=False)
    params: Any

    def __call__(self, *args, **kwargs):
        return self.module.call(self.params, *args, **kwargs)

    def truncate_at(self, module_path):
        """Truncate the model by returning the outputs of the given sub-module.

    Args:
      module_path: the full name of the module (e.g. '/module/sub_module').
        A list or dict of module paths can be provided to obtain the
        intermediate outputs of multiple modules.
    Returns:
      A new model with the truncated outputs. If module_path is a pytree of
      paths the outputs will have the same structure where each path is
      replaced by the corresponding intermediate output.
    """
        truncated_module_cls = TruncatedModule.partial(
            wrapped_module=self.module, truncate_path=module_path)
        return self.replace(module=truncated_module_cls)

    def __getattr__(self, name):
        value = getattr(self.module, name)
        if inspect.isclass(value) and issubclass(value, Module):

            def wrapper(*args, **kwargs):
                return value.call(self.params, *args, **kwargs)

            return wrapper
        raise AttributeError(f'No attribute named "{name}".')

    def __hash__(self):
        # Jax will call hash when model is passed to a function transform.
        # the compiled function should not be shared among model instances because
        # it closes over the specific parameters of this model instance.
        return id(self)
Example #23
class QGTOnTheFlyT(LinearOperator):
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    _mat_vec: Callable[[PyTree, float], PyTree] = Uninitialized
    """The S matrix-vector product as generated by mat_vec_factory.
    It's a jax.Partial, so can be used as pytree_node."""

    _params: PyTree = Uninitialized
    """The first input to apply_fun (parameters of the ansatz).
    Only used as a shape placeholder."""

    _chunking: bool = struct.field(pytree_node=False, default=False)
    """Whether the implementation with chunks is used which currently does not support vmapping over it"""

    def __matmul__(self, y):
        return onthefly_mat_treevec(self, y)

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs) -> PyTree:
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        return _to_dense(self)

    def __repr__(self):
        return f"QGTOnTheFly(diag_shift={self.diag_shift})"
Example #24
class GraphParameters:
    """Holds a graph's characteristics.

  Attributes:
    node_vocab_size: Number of different node types.
    num_relation_types: Number of different edge types.
    node_feature_dim: Dimension of the fixed representation of each node.
    node_feature_kind: Determines whether the node features are categorical or
      real.
    task_vocab_size: Number of different types of tasks.
    task_feature_dim: Dimension of the fixed representation of each task.
    task_feature_kind: Determines whether the task features are categorical or
      real.
  """
    node_vocab_size: int = struct.field(pytree_node=False)
    num_relation_types: int = struct.field(pytree_node=False)
    node_feature_dim: int = struct.field(pytree_node=False)
    node_feature_kind: NodeFeatureKind = struct.field(pytree_node=False)
    task_vocab_size: int = struct.field(pytree_node=False)
    task_feature_dim: int = struct.field(pytree_node=False)
    task_feature_kind: NodeFeatureKind = struct.field(pytree_node=False)
Example #25
class LazySMatrix:
    """
    Lazy representation of an S Matrix behaving like a linear operator.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    apply_fun: Callable[[PyTree, jnp.ndarray],
                        jnp.ndarray] = struct.field(pytree_node=False)
    """The forward pass of the Ansatz."""

    params: PyTree
    """The first input to apply_fun (parameters of the ansatz)."""

    samples: jnp.ndarray
    """The second input to apply_fun (points where the ansatz is evaluated)."""

    sr: SRLazy
    """Parameters for the solution of the system."""

    model_state: Optional[PyTree] = None
    """Optional state of the ansataz."""

    x0: Optional[PyTree] = None
    """Optional initial guess for the iterative solution."""
    def __matmul__(self, vec):
        return lazysmatrix_mat_treevec(self, vec)

    def __rtruediv__(self, y):
        return self.solve(y)

    def solve(self, y: PyTree, x0: Optional[PyTree] = None) -> PyTree:
        """
        Solve the linear system x=⟨S⟩⁻¹⟨y⟩ with the chosen iterative solver.

        Args:
            y: the vector y in the system above.
            x0: optional initial guess for the solution.

        Returns:
            x: the PyTree solving the system.
            info: optional additional information provided by the solver. Might be
                None if there is no additional information provided.
        """
        if x0 is None:
            x0 = self.x0

        out, info = apply_onthefly(
            self,
            y,
            x0,
        )

        return out, info

    @jax.jit
    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        Npars = nkjax.tree_size(self.params)
        I = jax.numpy.eye(Npars)
        return jax.vmap(lambda S, x: self @ x, in_axes=(None, 0))(self, I)
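A standalone sketch of the to_dense trick above: materialise a linear operator from its matrix-vector product by mapping it over the columns of the identity (the class above skips the final transpose because S is Hermitian). The helper name to_dense and the toy matrix are illustrative.

import jax
import jax.numpy as jnp

def to_dense(matvec, n):
    # row i of vmap(matvec)(eye) is matvec(e_i), i.e. column i of the operator
    return jax.vmap(matvec)(jnp.eye(n)).T

A = jnp.array([[2.0, 1.0], [0.0, 3.0]])
dense = to_dense(lambda v: A @ v, 2)   # recovers A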
Example #26
class Sampler(abc.ABC):
    """
    Abstract base class for all samplers.

    It contains the fields that all of them should possess, defining the common
    API.
    Note that fields marked with pytree_node=False are treated as static arguments
    when jitting.
    """

    hilbert: AbstractHilbert = struct.field(pytree_node=False)
    """Hilbert space to be sampled."""

    n_chains: int = struct.field(pytree_node=False, default=16)
    """Number of batches along the chain"""

    machine_pow: int = struct.field(default=2)
    """Exponent of the pdf sampled"""

    dtype: type = struct.field(pytree_node=False, default=np.float64)
    """DType of the states returned."""
    def __post_init__(self):
        # Raise an error if hilbert is not a Hilbert space
        if not isinstance(self.hilbert, AbstractHilbert):
            raise ValueError(
                "hilbert must be a subtype of netket.hilbert.AbstractHilbert, "
                + "got type {} instead.".format(type(self.hilbert)))

        if not isinstance(self.n_chains, int) or self.n_chains <= 0:
            raise ValueError("n_chains must be a positive integer")

        # if not isinstance(self.machine_pow, int) and self.machine_pow>= 0:
        #    raise ValueError("machine_pow must be a positivee integer")

    @property
    def n_batches(self) -> int:
        """
        The batch size of the configuration $\sigma$ used by this sampler.

        In general, it is equivalent to :attr:`~Sampler.n_chains`.
        """
        return self.n_chains

    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Returns a closure with the log_pdf function encoded by this sampler.

        Note: the result is returned as an HashablePartial so that the closure
        does not trigger recompilation.

        Args:
            model: The machine, or apply_fun

        Returns:
            the log probability density function
        """
        apply_fun = get_afun_if_module(model)
        log_pdf = HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).
            real,
            apply_fun,
        )
        return log_pdf

    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedType] = None,
    ) -> SamplerState:
        """
        Creates the structure holding the state of the sampler.

        If you want reproducible samples, you should specify `seed`, otherwise the state
        will be initialised randomly.

        If running across several MPI processes, all sampler_states are guaranteed to be
        in a different (but deterministic) state.
        This is achieved by first reducing (summing) the seed provided to every MPI rank,
        then generating n_rank seeds starting from the reduced one, and every rank is
        initialized with one of those seeds.

        The resulting state is guaranteed to be a frozen python dataclass (in particular,
        a flax's dataclass), and it can be serialized using Flax serialization methods.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. If not specified, a random seed will be used.

        Returns:
            The structure holding the state of the sampler. In general you should not expect
            it to be in a valid state, and should reset it before use.
        """
        key = nkjax.PRNGKey(seed)

        return sampler._init_state(get_afun_if_module(machine), parameters,
                                   nkjax.mpi_split(key))

    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Resets the state of the sampler. To be used every time the parameters are changed.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will be constructed
                by calling :code:`sampler.init_state(machine, parameters)` with a random seed.

        Returns:
            A valid sampler state.
        """
        if state is None:
            state = sampler_state(sampler, machine, parameters)

        return sampler._reset(get_afun_if_module(machine), parameters, state)

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Samples the next state in the markov chain.

        Args:
            machine: a Flax module or callable apply function with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will be constructed
                by calling :code:`sampler.reset(machine, parameters)` with a random seed.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        if state is None:
            state = sampler_state(sampler, machine, parameters)

        return sampler._sample_next(get_afun_if_module(machine), parameters,
                                    state)

    def sample(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Samples chain_length elements along the chains.

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: The model or callable to sample from (if it's a function it should have
                the signature :code:`f(parameters, σ) -> jnp.ndarray`).
            parameters: The PyTree of parameters of the model.
            state: current state of the sampler. If None, then initialises it.
            chain_length: (default=1), the length of the chains.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """

        return sample(sampler,
                      machine,
                      parameters,
                      state=state,
                      chain_length=chain_length)

    @partial(jax.jit, static_argnums=(1, 4))
    def _sample_chain(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: SamplerState,
        chain_length: int,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        _sample_next = lambda state, _: sampler.sample_next(
            machine, parameters, state)

        state, samples = jax.lax.scan(
            _sample_next,
            state,
            xs=None,
            length=chain_length,
        )

        return samples, state

    @abc.abstractmethod
    def _init_state(sampler, machine, params, seed) -> SamplerState:
        """
        Implementation of init_state for subclasses of Sampler.

        If you sub-class Sampler, you should define this and not init_state
        itself, because init_state contains some common logic.
        """
        raise NotImplementedError("init_state Not Implemented")

    @abc.abstractmethod
    def _reset(sampler, machine, parameters, state):
        """
        Implementation of reset for subclasses of Sampler.

        If you sub-class Sampler, you should define _reset and not reset
        itself, because reset contains some common logic.
        """
        raise NotImplementedError("reset Not Implemented")

    @abc.abstractmethod
    def _sample_next(sampler, machine, parameters, state=None):
        """
        Implementation of sample_next for subclasses of Sampler.

        If you sub-class Sampler, you should define _sample_next and not sample_next
        itself, because reset contains some common logic.
        """
        raise NotImplementedError("sample_next Not Implemented")
Example #27
class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples a Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be
    specified.
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """The number of replicas"""
    def __init__(
        self,
        hilbert: AbstractHilbert,
        rule: MetropolisRule,
        *,
        n_replicas: int = 32,
        **kwargs,
    ):
        """
        ``MetropolisSampler`` is a generic Metropolis-Hastings sampler using
        a transition rule to perform moves in the Markov Chain.
        The transition kernel is used to generate
        a proposed state :math:`s^\prime`, starting from the current state :math:`s`.
        The move is accepted with probability

        .. math::
            A(s\\rightarrow s^\\prime) = \\mathrm{min}\\left (1,\\frac{P(s^\\prime)}{P(s)} F(e^{L(s,s^\\prime)})\\right),

        where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)` is a
        user-provided function (the machine), :math:`p` is also user-provided with default value :math:`p=2`,
        and :math:`L(s,s^\\prime)` is a suitable correcting factor computed by the transition kernel.


        Args:
            hilbert: The hilbert space to sample
            rule: A `MetropolisRule` to generate random transitions from a given state as
                    well as uniform random states.
            n_chains: The number of Markov Chains to be run in parallel on a single process.
            n_sweeps: The number of exchanges that compose a single sweep.
                    If None, sweep_size is equal to the number of degrees of freedom being sampled
                    (the size of the input vector s to the machine).
            n_replicas: The number of parallel-tempering replicas (default = 32).
            machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float32).
        """

        if not config.FLAGS["NETKET_EXPERIMENTAL"]:
            raise RuntimeError("""
                               Parallel Tempering samplers are under development and
                               are known not to work.
                               
                               If you want to debug it, set the environment variable
                               NETKET_EXPERIMENTAL=1
                               """)

        object.__setattr__(self, "n_replicas", n_replicas)

        super().__init__(hilbert, rule, **kwargs)

    def __post_init__(self):
        super().__post_init__()
        if not (isinstance(self.n_replicas, int) and self.n_replicas > 0
                and np.mod(self.n_replicas, 2) == 0):
            raise ValueError("n_replicas must be an even integer > 0.")

    @property
    def n_batches(self):
        return self.n_chains * self.n_replicas

    def _init_state(sampler, machine, params: PyTree,
                    key: PRNGKey) -> MetropolisPtSamplerState:
        key_state, key_rule = jax.random.split(key, 2)
        σ = jnp.zeros(
            (sampler.n_batches, sampler.hilbert.size),
            dtype=sampler.dtype,
        )
        rule_state = sampler.rule.init_state(sampler, machine, params,
                                             key_rule)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            n_samples=0,
            n_accepted=0,
            beta=beta,
            beta_0_index=jnp.zeros((sampler.n_chains, ), dtype=int),
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=int),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains, )),
            exchange_steps=0,
        )

    def _reset(sampler, machine, parameters: PyTree,
               state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)

        σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return state.replace(
            σ=σ,
            rng=new_rng,
            rule_state=rule_state,
            n_samples=0,
            n_accepted=0,
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas)),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains)),
            exchange_steps=0,
            # beta=beta,
            # beta_0_index=jnp.zeros((sampler.n_chains,), dtype=jnp.int32),
        )

    def _sample_next(sampler, machine, parameters: PyTree,
                     state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)
        # def cbr(data):
        #    new_rng, rng = data
        #    print("sample_next newrng:\n", new_rng,  "\nand rng:\n", rng)
        #    return new_rng
        # new_rng = hcb.call(
        #   cbr,
        #   (new_rng, rng),
        #   result_shape=jax.ShapeDtypeStruct(new_rng.shape, new_rng.dtype),
        # )

        with loops.Scope() as s:
            s.key = rng
            s.σ = state.σ
            s.log_prob = sampler.machine_pow * machine(parameters,
                                                       state.σ).real
            s.beta = state.beta

            # for logging
            s.beta_0_index = state.beta_0_index
            s.n_accepted_per_beta = state.n_accepted_per_beta
            s.beta_position = state.beta_position
            s.beta_diffusion = state.beta_diffusion

            for i in s.range(sampler.n_sweeps):
                # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
                s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)

                # def cbi(data):
                #    i, beta = data
                #    print("sweep #", i, " for beta=\n", beta)
                #    return beta
                # beta = hcb.call(
                #   cbi,
                #   (i, s.beta),
                #   result_shape=jax.ShapeDtypeStruct(s.beta.shape, s.beta.dtype),
                # )
                beta = s.beta

                σp, log_prob_correction = sampler.rule.transition(
                    sampler, machine, parameters, state, key1, s.σ)
                proposal_log_prob = sampler.machine_pow * machine(
                    parameters, σp).real

                uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
                if log_prob_correction is not None:
                    do_accept = uniform < jnp.exp(
                        beta.reshape((-1, )) *
                        (proposal_log_prob - s.log_prob + log_prob_correction))
                else:
                    do_accept = uniform < jnp.exp(
                        beta.reshape(
                            (-1, )) * (proposal_log_prob - s.log_prob))

                # do_accept must match ndim of proposal and state (which is 2)
                s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
                n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                             proposal_log_prob, s.log_prob)

                # exchange betas

                # randomly decide if every set of replicas should be swapped in even or odd order
                swap_order = jax.random.randint(
                    key3,
                    minval=0,
                    maxval=2,
                    shape=(sampler.n_chains, ),
                )  # 0 or 1
                iswap_order = jnp.mod(swap_order + 1, 2)  #  1 or 0

                # indices of even swapped elements (per-row)
                idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                    (1, -1)) + swap_order.reshape((-1, 1))
                # indices of odd swapped elements (per-row)
                inn = (idxs + 1) % sampler.n_replicas

                # for every rows of the input, swap elements at idxs with elements at inn
                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def swap_rows(beta_row, idxs, inn):
                    proposed_beta = jax.ops.index_update(
                        beta_row,
                        idxs,
                        beta_row[inn],
                        unique_indices=True,
                        indices_are_sorted=True,
                    )
                    proposed_beta = jax.ops.index_update(
                        proposed_beta,
                        inn,
                        beta_row[idxs],
                        unique_indices=True,
                        indices_are_sorted=False,
                    )
                    return proposed_beta

                proposed_beta = swap_rows(beta, idxs, inn)

                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def compute_proposed_prob(prob, idxs, inn):
                    prob_rescaled = prob[idxs] + prob[inn]
                    return prob_rescaled

                # compute the probability of the swaps
                log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                prob_rescaled = jnp.exp(
                    compute_proposed_prob(log_prob, idxs, inn))

                uniform = jax.random.uniform(key4,
                                             shape=(sampler.n_chains,
                                                    sampler.n_replicas // 2))

                do_swap = uniform < prob_rescaled

                do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                    (-1, sampler.n_replicas))  #  concat along last dimension
                # roll if swap_order is odd
                @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
                def fix_swap(do_swap, swap_order):
                    return jax.lax.cond(swap_order == 0, lambda x: x,
                                        lambda x: jnp.roll(x, 1), do_swap)

                do_swap = fix_swap(do_swap, swap_order)
                # jax.experimental.host_callback.id_print(state.beta)
                # jax.experimental.host_callback.id_print(proposed_beta)

                new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

                def cb(data):
                    _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
                    print("--------.---------.---------.--------")
                    print("     cur beta:\n", _bt)
                    print("proposed beta:\n", _pbt)
                    print("     new beta:\n", new_beta)
                    print("swaporder :", so)
                    print("do_swap :\n", do_swap)
                    print("log_prob;\n", log_prob)
                    print("prob_rescaled;\n", prob)
                    return new_beta

                # new_beta = hcb.call(
                #    cb,
                #    (
                #        beta,
                #        proposed_beta,
                #        new_beta,
                #        swap_order,
                #        do_swap,
                #        log_prob,
                #        prob_rescaled,
                #    ),
                #    result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
                # )
                # s.beta = new_beta

                swap_order = swap_order.reshape(-1)

                beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                        in_axes=(0, 0),
                                        out_axes=0)(do_swap,
                                                    state.beta_0_index)
                proposed_beta_0_index = jnp.mod(
                    state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                    (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                    sampler.n_replicas,
                )

                s.beta_0_index = jnp.where(beta_0_moved, proposed_beta_0_index,
                                           s.beta_0_index)

                # swap acceptances
                swapped_n_accepted_per_beta = swap_rows(
                    n_accepted_per_beta, idxs, inn)
                s.n_accepted_per_beta = jax.numpy.where(
                    do_swap,
                    swapped_n_accepted_per_beta,
                    n_accepted_per_beta,
                )

                # Update statistics to compute diffusion coefficient of replicas
                # Total exchange steps performed
                delta = s.beta_0_index - s.beta_position
                s.beta_position = s.beta_position + delta / (
                    state.exchange_steps + i)
                delta2 = s.beta_0_index - s.beta_position
                s.beta_diffusion = s.beta_diffusion + delta * delta2

            new_state = state.replace(
                rng=new_rng,
                σ=s.σ,
                # n_accepted=s.accepted,
                n_samples=state.n_samples +
                sampler.n_sweeps * sampler.n_chains,
                beta=s.beta,
                beta_0_index=s.beta_0_index,
                beta_position=s.beta_position,
                beta_diffusion=s.beta_diffusion,
                exchange_steps=state.exchange_steps + sampler.n_sweeps,
                n_accepted_per_beta=s.n_accepted_per_beta,
            )

        offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                             sampler.n_replicas)

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]

    def __repr__(sampler):
        return ("MetropolisPTSampler(" +
                "\n  hilbert = {},".format(sampler.hilbert) +
                "\n  rule = {},".format(sampler.rule) +
                "\n  n_chains = {},".format(sampler.n_chains) +
                "\n  machine_power = {},".format(sampler.machine_pow) +
                "\n  reset_chain = {},".format(sampler.reset_chain) +
                "\n  n_sweeps = {},".format(sampler.n_sweeps) +
                "\n  dtype = {},".format(sampler.dtype) + ")")

    def __str__(sampler):
        return ("MetropolisPTSampler(" + "rule = {}, ".format(sampler.rule) +
                "n_chains = {}, ".format(sampler.n_chains) +
                "machine_power = {}, ".format(sampler.machine_pow) +
                "reset_chain = {}, ".format(sampler.reset_chain) +
                "n_sweeps = {}, ".format(sampler.n_sweeps) +
                "dtype = {})".format(sampler.dtype))
Example #28
class Point:
    x: float
    y: float
    meta: Any = struct.field(pytree_node=False)
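A hedged usage sketch showing what pytree_node=False means here: with the flax.struct.dataclass decorator (not shown in the listing), x and y are traced pytree leaves while meta is treated as static metadata. The class name Point2 and the default value are illustrative, chosen to avoid clashing with the example above.

import jax
from flax import struct

@struct.dataclass
class Point2:
    x: float
    y: float
    meta: str = struct.field(pytree_node=False, default="label")

p = Point2(1.0, 2.0)
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)   # meta stays "label"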
Example #29
class QuantizedValue:
  """State associated with quantized value."""
  quantized: chex.Array
  diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
  bucket_size: chex.Array
  quantized_dtype: jnp.dtype = struct.field(
      pytree_node=False)  # Dtype for the quantized value.
  extract_diagonal: bool = struct.field(
      pytree_node=False)  # In case its centered.
  shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

  @classmethod
  def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
    if isinstance(fvalue, list) and not fvalue:
      return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
    quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
        fvalue, quantized_dtype, extract_diagonal)
    return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
                          quantized_dtype, extract_diagonal,
                          list(quantized.shape))

  # Quantization is from Lingvo JAX optimizers.
  # We extend it for int16 quantization of PSD matrices.
  @classmethod
  def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
    """Returns quantized value and the bucket."""
    if quantized_dtype == jnp.float32:
      return fvalue, [], []
    elif quantized_dtype == jnp.bfloat16:
      return fvalue.astype(jnp.bfloat16), [], []

    float_dtype = fvalue.dtype
    if quantized_dtype == jnp.int8:
      # value -128 is not used.
      num_buckets = jnp.array(127.0, dtype=float_dtype)
    elif quantized_dtype == jnp.int16:
      # value -32768 is not used.
      num_buckets = jnp.array(32767.0, dtype=float_dtype)
    else:
      raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
    # max value is mapped to num_buckets

    if extract_diagonal and fvalue.ndim != 2:
      raise ValueError(
          f'Input array {fvalue} must be 2D to work with extract_diagonal.')

    diagonal_fvalue = []
    if extract_diagonal:
      diagonal_fvalue = jnp.diag(fvalue)
      # Remove the diagonal entries.
      fvalue = fvalue - jnp.diag(diagonal_fvalue)

    # TODO(rohananil): Extend this by making use of information about the blocks
    # SM3 style which will be useful for diagonal statistics
    # We first decide the scale.
    if fvalue.ndim < 1:
      raise ValueError(
          f'Input array {fvalue} must have a strictly positive number of '
          'dimensions.')

    max_abs = jnp.max(jnp.abs(fvalue), axis=0)
    bucket_size = max_abs / num_buckets
    bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
    # To avoid divide by 0.0
    bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
                           jnp.ones_like(bs_expanded))
    ratio = fvalue / bs_nonzero
    # We use rounding to remove bias.
    quantized = jnp.round(ratio)
    return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

  def to_float(self):
    """Returns the float value."""
    if isinstance(self.quantized, list) and not self.quantized:
      return self.quantized

    if self.quantized_dtype == jnp.float32:
      return self.quantized

    if self.quantized_dtype == jnp.bfloat16:
      return self.quantized.astype(jnp.float32)

    float_dtype = self.bucket_size.dtype
    bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
    val = self.quantized.astype(float_dtype) * bucket_size
    if self.extract_diagonal:
      val += jnp.diag(self.diagonal)
    return val
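A hedged round-trip sketch for QuantizedValue; it assumes the class above is decorated with flax.struct.dataclass in its original module, as the use of struct.field implies, and the toy matrix is illustrative.

import jax.numpy as jnp

m = jnp.array([[1.0, -2.0], [0.5, 4.0]])
qv = QuantizedValue.from_float_value(m, jnp.int8)
approx = qv.to_float()    # int8 buckets rescaled back, close to m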
Example #30
class Glue(struct.PyTreeNode):
    sentences: Tuple[str, Union[str, None]] = struct.field(pytree_node=False)
    num_labels: int = struct.field(pytree_node=False)
    is_regression: bool = struct.field(pytree_node=False)