예제 #1
0
파일: base.py 프로젝트: tobiaswiener/netket
    def sample(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Draw `chain_length` batches of samples along every chain.

        Arguments:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. When omitted, a fresh state is
                obtained from :code:`sampler.reset(machine, parameters)`.
            chain_length: The length of the chains (default = 1).

        Returns:
            σ: The generated batches of samples.
            state: The new state of the sampler.
        """
        # Fall back to a freshly reset state when the caller did not supply one.
        start_state = sampler.reset(machine, parameters) if state is None else state
        # The private implementation works on the wrapped log-pdf, never on the raw module.
        log_pdf = wrap_afun(machine)
        return sampler._sample_chain(log_pdf, parameters, start_state, chain_length)
예제 #2
0
파일: base.py 프로젝트: tobiaswiener/netket
    def samples(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Iterator[jnp.ndarray]:
        """
        Returns a generator sampling `chain_length` batches of samples along the chains.

        Arguments:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, then initialize and reset it.
            chain_length: The length of the chains (default = 1).

        Yields:
            One batch of samples per step, for `chain_length` steps in total. Each batch
            is the output of a single-step `_sample_chain` call with its leading axis
            (of length 1) removed.
        """
        if state is None:
            state = sampler.reset(machine, parameters)

        # Wrap once, outside the loop, so every step reuses the same wrapped function.
        machine = wrap_afun(machine)

        for _ in range(chain_length):
            # Advance the chains by exactly one step. The result carries a leading
            # axis of length 1 (the requested chain length), which is dropped here.
            # `batch[0]` equals the previous `[0, :, :]` for rank >= 3 results
            # without hard-coding the rank.
            batch, state = sampler._sample_chain(machine, parameters, state, 1)
            yield batch[0]
예제 #3
0
파일: base.py 프로젝트: tobiaswiener/netket
    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedT] = None,
    ) -> SamplerState:
        """
        Build the (not yet reset) structure that holds the sampler's state.

        Pass `seed` to get reproducible samples; without it the state is initialised
        from a random seed.

        Under MPI, every process ends up with a different but deterministic state.
        The seed provided to every rank is first reduced (summed), `n_rank` seeds are
        then generated from the reduced value, and each rank is initialised with one
        of them.

        The returned state is a frozen Python dataclass (specifically a Flax
        dataclass) and can be serialized with the Flax serialization methods.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. If not specified, a random seed will be used.

        Returns:
            The structure holding the state of the sampler. It is not guaranteed to be
            valid yet, so reset it before use.
        """
        # Turn the seed into a key, then derive this rank's own key from it.
        rank_key = nkjax.mpi_split(nkjax.PRNGKey(seed))
        return sampler._init_state(wrap_afun(machine), parameters, rank_key)
예제 #4
0
    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Advance the Markov chain by one step.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. When omitted, a fresh state is
                obtained from :code:`sampler.reset(machine, parameters)`.

        Returns:
            state: The new state of the sampler.
            σ: The next batch of samples.

        Note:
            The state comes first (unlike in `sample`) so that this method can be used
            directly as the body of a scan, whose first returned value must be the carry.
        """
        current = sampler.reset(machine, parameters) if state is None else state
        return sampler._sample_next(wrap_afun(machine), parameters, current)
예제 #5
0
파일: base.py 프로젝트: chrisrothUT/netket
    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the Markov chain.

        Args:
            machine: a Flax module or callable apply function with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will be constructed
                by calling :code:`sampler.reset(machine, parameters)` with a random seed.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        # Note: the return order is inverted wrt `.sample` because when called inside of
        # a scan function the first returned argument should be the state.
        # The return annotation reflects that order (state first, samples second).

        if state is None:
            # Build *and* reset a fresh state, as documented above. `sampler.reset`
            # with no state constructs one and then runs `_reset` on it; using the
            # bare constructed state would skip that reset step.
            state = sampler.reset(machine, parameters)

        return sampler._sample_next(wrap_afun(machine), parameters, state)
예제 #6
0
파일: base.py 프로젝트: chrisrothUT/netket
    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Bring the sampler state back to a valid starting point.

        Call this every time the parameters are changed.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. When omitted, a new one is built
                with :code:`sampler_state(sampler, machine, parameters)` before resetting.

        Returns:
            A valid sampler state.
        """
        to_reset = (
            state if state is not None else sampler_state(sampler, machine, parameters)
        )
        return sampler._reset(wrap_afun(machine), parameters, to_reset)
예제 #7
0
    def __init__(
        self,
        sampler: Sampler,
        model=None,
        *,
        n_samples: Optional[int] = None,
        n_samples_per_rank: Optional[int] = None,
        n_discard: Optional[int] = None,  # deprecated
        n_discard_per_chain: Optional[int] = None,
        variables: Optional[PyTree] = None,
        init_fun: Optional[NNInitFunc] = None,
        apply_fun: Optional[Callable] = None,
        sample_fun: Optional[Callable] = None,
        seed: Optional[SeedT] = None,
        sampler_seed: Optional[SeedT] = None,
        mutable: bool = False,
        training_kwargs: Optional[Dict] = None,
    ):
        """
        Constructs the MCState.

        Args:
            sampler: The sampler
            model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
            n_samples: the total number of samples across chains and processes when sampling (default=1000).
            n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be
                specified together with n_samples (default=None).
            n_discard_per_chain: number of discarded samples at the beginning of each monte-carlo chain (default=0 for exact sampler,
                and n_samples/10 for approximate sampler).
            seed: rng seed used to generate a set of variables (only if `variables` is not passed). Defaults to a random one.
            sampler_seed: rng seed used to initialise the sampler. If not given but `seed` is, it is derived
                deterministically from `seed`; otherwise it defaults to a random one.
            mutable: Specifies which variable collections are mutable. Use it to specify if the model has a state
                that can change during evaluation, but that should not be optimised. See also flax.linen.module.apply
                documentation (default=False)
            init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
                initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
                a non-standard init method.
            variables: Optional initial value for the variables (parameters and model state) of the model.
            apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
                `model.apply(variables, σ)`. Specify only if your network has a non-standard apply method.
            sample_fun: Optional function used to sample the state, if it is not the same as `apply_fun`.
            training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during
                training. Useful for example when you have a batchnorm layer that constructs the average/mean only
                during training. Defaults to an empty dict.
            n_discard: DEPRECATED. Please use `n_discard_per_chain` which has the same behaviour.
        """
        super().__init__(sampler.hilbert)

        # Avoid a shared mutable default argument; behaves as the former `{}` default.
        if training_kwargs is None:
            training_kwargs = {}

        # Init type 1: pass in a model
        if model is not None:
            # extract init and apply functions
            # Wrap it in an HashablePartial because if two instances of the same model are provided,
            # model.apply and model2.apply will be different methods forcing recompilation, but
            # model and model2 will have the same hash.
            _, model = maybe_wrap_module(model)

            self._model = model

            self._init_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.init(*args, **kwargs),
                model)
            self._apply_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.apply(*args, **kwargs),
                model)

        # Init type 2: pass in explicit apply (and optionally init) functions
        elif apply_fun is not None:
            self._apply_fun = apply_fun

            if init_fun is not None:
                self._init_fun = init_fun
            elif variables is None:
                # Without an init function the variables cannot be generated.
                raise ValueError(
                    "If you don't provide variables, you must pass a valid init_fun."
                )

            self._model = wrap_afun(apply_fun)

        else:
            raise ValueError(
                "Must either pass the model or apply_fun, otherwise how do you think we "
                "gonna evaluate the model?")

        # default argument for n_samples/n_samples_per_rank
        if n_samples is None and n_samples_per_rank is None:
            n_samples = 1000
        elif n_samples is not None and n_samples_per_rank is not None:
            raise ValueError(
                "Only one argument between `n_samples` and `n_samples_per_rank` "
                "can be specified at the same time.")

        # Handle the deprecated `n_discard` alias.
        if n_discard is not None and n_discard_per_chain is not None:
            raise ValueError(
                "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
                "Specify only `n_discard_per_chain`.")
        elif n_discard is not None:
            warn_deprecation(
                "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
                "Please update your code to use `n_discard_per_chain`.")
            n_discard_per_chain = n_discard

        # Sampling uses the apply function unless a dedicated one is given.
        if sample_fun is not None:
            self._sample_fun = sample_fun
        else:
            self._sample_fun = self._apply_fun

        self.mutable = mutable
        self.training_kwargs = flax.core.freeze(training_kwargs)

        if variables is not None:
            self.variables = variables
        else:
            self.init(seed, dtype=sampler.dtype)

        # Derive a reproducible sampler seed from `seed` when only the latter is given.
        if sampler_seed is None and seed is not None:
            _, sampler_seed = jax.random.split(nkjax.PRNGKey(seed), 2)

        self._sampler_seed = sampler_seed
        self.sampler = sampler

        if n_samples is not None:
            self.n_samples = n_samples
        else:
            self.n_samples_per_rank = n_samples_per_rank

        self.n_discard_per_chain = n_discard_per_chain
예제 #8
0
    def __init__(
        self,
        hilbert: AbstractHilbert,
        model=None,
        *,
        variables: Optional[PyTree] = None,
        init_fun: NNInitFunc = None,
        apply_fun: Callable = None,
        seed: Optional[SeedT] = None,
        mutable: bool = False,
        training_kwargs: Dict = {},
        dtype=float,
    ):
        """
        Constructs the ExactState.

        Args:
            hilbert: The Hilbert space
            model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
            parameters: Optional PyTree of weights from which to start.
            seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one.
            mutable: Dict specifing mutable arguments. Use it to specify if the model has a state that can change
                during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation
                (default=False)
            init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
                initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
                a non-standard init method.
            variables: Optional initial value for the variables (parameters and model state) of the model.
            apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defafults to
                `model.apply(variables, σ)`. specify only if your network has a non-standard apply method.
            training_kwargs: a dict containing the optionaal keyword arguments to be passed to the apply_fun during training.
                Useful for example when you have a batchnorm layer that constructs the average/mean only during training.
        """
        super().__init__(hilbert)

        # Init type 1: pass in a model
        if model is not None:
            # extract init and apply functions
            # Wrap it in an HashablePartial because if two instances of the same model are provided,
            # model.apply and model2.apply will be different methods forcing recompilation, but
            # model and model2 will have the same hash.
            _, model = maybe_wrap_module(model)

            self._model = model

            self._init_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.init(*args, **kwargs),
                model)
            self._apply_fun = wrap_to_support_scalar(
                nkjax.HashablePartial(
                    lambda model, *args, **kwargs: model.apply(
                        *args, **kwargs), model))

        elif apply_fun is not None:
            self._apply_fun = wrap_to_support_scalar(apply_fun)

            if init_fun is not None:
                self._init_fun = init_fun
            elif variables is None:
                raise ValueError(
                    "If you don't provide variables, you must pass a valid init_fun."
                )

            self._model = wrap_afun(apply_fun)

        else:
            raise ValueError(
                "Must either pass the model or apply_fun, otherwise how do you think we"
                "gonna evaluate the model?")

        self.mutable = mutable
        self.training_kwargs = flax.core.freeze(training_kwargs)

        if variables is not None:
            self.variables = variables
        else:
            self.init(seed, dtype=dtype)

        self._states = None
        """
        Caches the output of `self._all_states()`.
        """

        self._array = None
        """
        Caches the output of `self.to_array()`.
        """

        self._pdf = None
        """