Example #1
def test_batching_wrapper():
    import jax.numpy as jnp

    from netket.utils import wrap_to_support_scalar

    def applyfun(pars, x, mutable=False):
        # this assert fails if the wrapper is not working
        assert x.ndim > 1
        if not mutable:
            return x.sum(axis=-1)
        else:
            return (x.sum(axis=-1), {})

    # check same hash
    assert hash(wrap_to_support_scalar(applyfun)) == hash(
        wrap_to_support_scalar(applyfun)
    )
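    # Equal hashes let jit-compiled callers treat the two wrappers as the same
    # function and reuse cached compilations instead of retracing (see the
    # HashablePartial comment in the constructors below).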

    afun = wrap_to_support_scalar(applyfun)

    x = jnp.ones(5)
    xb = jnp.ones((1, 5))

    # no mutable state
    res = afun(None, x)
    assert res.shape == ()
    assert res == jnp.sum(x, axis=-1)
    res = afun(None, xb)
    assert res.shape == (1,)
    assert res == jnp.sum(xb, axis=-1)

    # mutable state
    res = afun(None, x, mutable=True)[0]
    assert res.shape == ()
    assert res == jnp.sum(x, axis=-1)
    res = afun(None, xb, mutable=True)[0]
    assert res.shape == (1,)
    assert res == jnp.sum(xb, axis=-1)
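
# A minimal illustrative sketch (ours, not NetKet's actual implementation) of
# a wrapper with the behaviour exercised above: a single unbatched
# configuration (ndim == 1) is promoted to a batch of one before calling the
# wrapped function, and the spurious batch axis is squeezed off the output.
# Unlike the real `wrap_to_support_scalar`, a plain closure like this would
# not satisfy the equal-hash assertion; NetKet uses hashable partials for that.
import jax.numpy as jnp


def wrap_to_support_scalar_sketch(apply_fun):
    def wrapped(pars, x, **kwargs):
        x = jnp.asarray(x)
        single = x.ndim == 1  # a lone configuration, not a batch
        if single:
            x = x.reshape((1, -1))
        res = apply_fun(pars, x, **kwargs)
        if single:
            # Calls with mutable state return (output, state): squeeze only the output.
            if isinstance(res, tuple):
                res = (res[0].squeeze(0), *res[1:])
            else:
                res = res.squeeze(0)
        return res

    return wrapped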
Example #2
    def __init__(
        self,
        sampler: Sampler,
        model=None,
        *,
        n_samples: Optional[int] = None,
        n_samples_per_rank: Optional[int] = None,
        n_discard: Optional[int] = None,  # deprecated
        n_discard_per_chain: Optional[int] = None,
        chunk_size: Optional[int] = None,
        variables: Optional[PyTree] = None,
        init_fun: Optional[NNInitFunc] = None,
        apply_fun: Optional[Callable] = None,
        sample_fun: Optional[Callable] = None,
        seed: Optional[SeedT] = None,
        sampler_seed: Optional[SeedT] = None,
        mutable: bool = False,
        training_kwargs: Optional[Dict] = None,
    ):
        """
        Constructs the MCState.

        Args:
            sampler: The sampler.
            model: (Optional) The model. If not provided, you must provide `init_fun` and `apply_fun`.
            n_samples: the total number of samples across chains and processes when sampling
                (defaults to the first multiple of the number of chains at or above 1000).
            n_samples_per_rank: the total number of samples across chains on one process when
                sampling. Cannot be specified together with `n_samples` (default=None).
            n_discard_per_chain: number of discarded samples at the beginning of each Monte Carlo
                chain (default=0 for exact samplers, and n_samples/10 for approximate samplers).
            chunk_size: (Optional) If specified, evaluations of the model are broken into chunks
                of at most this size, trading speed for a lower memory footprint (default=None).
            seed: rng seed used to generate a set of parameters (only if `variables` is not
                passed). Defaults to a random one.
            sampler_seed: rng seed used to initialise the sampler. Defaults to a random one.
            mutable: specifies which collections of the model state are mutable, i.e. can change
                during evaluation but should not be optimised. See also the `flax.linen.Module.apply`
                documentation (default=False).
            init_fun: Function with the signature `f(model, shape, rng_key, dtype) -> Optional_state, parameters`,
                used to initialise the parameters. Defaults to the standard Flax initialiser. Specify
                this only if your network has a non-standard init method.
            variables: Optional initial value for the variables (parameters and model state) of the model.
            apply_fun: Function with the signature `f(model, variables, σ)` that evaluates the model.
                Defaults to `model.apply(variables, σ)`. Specify this only if your network has a
                non-standard apply method.
            sample_fun: Optional function used to sample the state, if it is not the same as `apply_fun`.
            training_kwargs: a dict containing the optional keyword arguments passed to `apply_fun`
                during training. Useful, for example, when you have a batch-norm layer that constructs
                the average/mean only during training.
            n_discard: DEPRECATED. Please use `n_discard_per_chain`, which has the same behaviour.
        """
        super().__init__(sampler.hilbert)

        # Init type 1: pass in a model
        if model is not None:
            # extract init and apply functions
            # Wrap it in a HashablePartial: if two equal instances of the same model
            # are provided, model.apply and model2.apply are distinct bound methods
            # that would force recompilation, even though model and model2 hash equal.
            _, model = maybe_wrap_module(model)

            self._model = model

            self._init_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.init(*args, **kwargs), model
            )
            self._apply_fun = wrap_to_support_scalar(
                nkjax.HashablePartial(
                    lambda model, pars, x, **kwargs: model.apply(pars, x, **kwargs),
                    model,
                )
            )

        elif apply_fun is not None:
            self._apply_fun = wrap_to_support_scalar(apply_fun)

            if init_fun is not None:
                self._init_fun = init_fun
            elif variables is None:
                raise ValueError(
                    "If you don't provide variables, you must pass a valid init_fun."
                )

            self._model = wrap_afun(apply_fun)

        else:
            raise ValueError(
                "You must pass either `model` or `apply_fun`; otherwise there is "
                "no way to evaluate the model."
            )

        # default argument for n_samples/n_samples_per_rank
        if n_samples is None and n_samples_per_rank is None:
            # get the first multiple of sampler.n_chains above 1000 to avoid
            # printing a warning on construction
            n_samples = int(np.ceil(1000 / sampler.n_chains) * sampler.n_chains)
        elif n_samples is not None and n_samples_per_rank is not None:
            raise ValueError(
                "Only one argument between `n_samples` and `n_samples_per_rank`"
                "can be specified at the same time."
            )

        if n_discard is not None and n_discard_per_chain is not None:
            raise ValueError(
                "`n_discard` has been renamed to `n_discard_per_chain` and deprecated."
                "Specify only `n_discard_per_chain`."
            )
        elif n_discard is not None:
            warn_deprecation(
                "`n_discard` has been renamed to `n_discard_per_chain` and deprecated."
                "Please update your code to use `n_discard_per_chain`."
            )
            n_discard_per_chain = n_discard

        if sample_fun is not None:
            self._sample_fun = sample_fun
        else:
            self._sample_fun = self._apply_fun

        self.mutable = mutable
        self.training_kwargs = flax.core.freeze(training_kwargs or {})

        if variables is not None:
            self.variables = variables
        else:
            self.init(seed, dtype=sampler.dtype)

        if sampler_seed is None and seed is not None:
            key, key2 = jax.random.split(nkjax.PRNGKey(seed), 2)
            sampler_seed = key2

        self._sampler_seed = sampler_seed
        self.sampler = sampler

        if n_samples is not None:
            self.n_samples = n_samples
        else:
            self.n_samples_per_rank = n_samples_per_rank

        self.n_discard_per_chain = n_discard_per_chain

        self.chunk_size = chunk_size
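
# Hypothetical usage sketch (the Hilbert space, sampler, model and numbers
# below are arbitrary illustrations, assuming the standard NetKet public API
# via `import netket as nk`):
import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=8)
sampler = nk.sampler.MetropolisLocal(hi, n_chains=16)
model = nk.models.RBM(alpha=1)

# If `n_samples` is omitted it defaults to the first multiple of
# `sampler.n_chains` at or above 1000; passing both `n_samples` and
# `n_samples_per_rank` raises a ValueError (see the constructor above).
vstate = nk.vqs.MCState(sampler, model, n_samples=1024, n_discard_per_chain=16)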
Example #3
    def __init__(
        self,
        hilbert: AbstractHilbert,
        model=None,
        *,
        variables: Optional[PyTree] = None,
        init_fun: Optional[NNInitFunc] = None,
        apply_fun: Optional[Callable] = None,
        seed: Optional[SeedT] = None,
        mutable: bool = False,
        training_kwargs: Optional[Dict] = None,
        dtype=float,
    ):
        """
        Constructs the ExactState.

        Args:
            hilbert: The Hilbert space.
            model: (Optional) The model. If not provided, you must provide `init_fun` and `apply_fun`.
            seed: rng seed used to generate a set of parameters (only if `variables` is not
                passed). Defaults to a random one.
            mutable: specifies which collections of the model state are mutable, i.e. can change
                during evaluation but should not be optimised. See also the `flax.linen.Module.apply`
                documentation (default=False).
            init_fun: Function with the signature `f(model, shape, rng_key, dtype) -> Optional_state, parameters`,
                used to initialise the parameters. Defaults to the standard Flax initialiser. Specify
                this only if your network has a non-standard init method.
            variables: Optional initial value for the variables (parameters and model state) of the model.
            apply_fun: Function with the signature `f(model, variables, σ)` that evaluates the model.
                Defaults to `model.apply(variables, σ)`. Specify this only if your network has a
                non-standard apply method.
            training_kwargs: a dict containing the optional keyword arguments passed to `apply_fun`
                during training. Useful, for example, when you have a batch-norm layer that constructs
                the average/mean only during training.
            dtype: the dtype of the variables generated at initialisation (default=float).
        """
        super().__init__(hilbert)

        # Init type 1: pass in a model
        if model is not None:
            # extract init and apply functions
            # Wrap it in a HashablePartial: if two equal instances of the same model
            # are provided, model.apply and model2.apply are distinct bound methods
            # that would force recompilation, even though model and model2 hash equal.
            _, model = maybe_wrap_module(model)

            self._model = model

            self._init_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.init(*args, **kwargs), model
            )
            self._apply_fun = wrap_to_support_scalar(
                nkjax.HashablePartial(
                    lambda model, *args, **kwargs: model.apply(*args, **kwargs),
                    model,
                )
            )

        elif apply_fun is not None:
            self._apply_fun = wrap_to_support_scalar(apply_fun)

            if init_fun is not None:
                self._init_fun = init_fun
            elif variables is None:
                raise ValueError(
                    "If you don't provide variables, you must pass a valid init_fun."
                )

            self._model = wrap_afun(apply_fun)

        else:
            raise ValueError(
                "You must pass either `model` or `apply_fun`; otherwise there is "
                "no way to evaluate the model."
            )

        self.mutable = mutable
        self.training_kwargs = flax.core.freeze(training_kwargs or {})

        if variables is not None:
            self.variables = variables
        else:
            self.init(seed, dtype=dtype)

        self._states = None
        """
        Caches the output of `self._all_states()`.
        """

        self._array = None
        """
        Caches the output of `self.to_array()`.
        """

        self._pdf = None
        """