Example #1
    def _set_prebuilt_dataset(self, main, remaining, validation_main,
                              validation_remaining):
        """Set preconstructed dataset iterators

        Parameters
        ----------
        main : list of tf.data.Dataset().as_numpy_iterators()
            The simulations generated at the fiducial model parameter values
            used for calculating the covariance of network outputs and their
            derivatives with respect to the physical model parameters (for
            fitting). These are served ``n_per_device`` at a time as a numpy
            iterator from a TensorFlow dataset.
        remaining : list of tf.data.Dataset().as_numpy_iterators()
            The ``n_s - n_d`` simulations generated at the fiducial model
            parameter values used for calculating the covariance of network
            outputs with a derivative counterpart (for fitting). These are
            served ``n_per_device`` at a time as a numpy iterator from a
            TensorFlow dataset.
        validation_main : list of tf.data.Dataset().as_numpy_iterators()
            The simulations generated at the fiducial model parameter values
            used for calculating the covariance of network outputs and their
            derivatives with respect to the physical model parameters
            (for validation). These are served ``n_per_device`` at a time as a
            numpy iterator from a TensorFlow dataset.
        validation_remaining : list of tf.data.Dataset().as_numpy_iterators()
            The ``n_s - n_d`` simulations generated at the fiducial model
            parameter values used for calculating the covariance of network
            outputs with a derivative counterpart (for validation). Served
            ``n_per_device`` at a time as a numpy iterator from a TensorFlow
            dataset.

        Raises
        ------
        ValueError
            If ``main`` or ``remaining`` are None
        ValueError
            If the length of any input list is not equal to the number of
            devices
        TypeError
            If any input is not a list
        """
        self.main = _check_type(main, list, "main", shape=self.n_devices)
        self.remaining = _check_type(remaining,
                                     list,
                                     "remaining",
                                     shape=self.n_devices)
        if ((validation_main is not None)
                and (validation_remaining is not None)):
            self.validation_main = _check_type(validation_main,
                                               list,
                                               "validation_main",
                                               shape=self.n_devices)
            self.validation_remaining = _check_type(validation_remaining,
                                                    list,
                                                    "validation_remaining",
                                                    shape=self.n_devices)
            self.validate = True
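
A minimal sketch of how the ``main`` and ``remaining`` iterator lists
might be built before being handed to ``_set_prebuilt_dataset``. All
sizes, shapes and the interleaved split below are illustrative
assumptions (the class itself reshapes arrays per device), not part of
the API:

import numpy as onp
import tensorflow as tf

n_devices, n_per_device = 2, 100
n_s, n_d = 1200, 1000  # n_s - n_d = 200 "remaining" simulations
input_shape, n_params = (10,), 2

# placeholder arrays standing in for real simulations
fiducial = onp.zeros((n_s,) + input_shape, dtype=onp.float32)
derivative = onp.zeros((n_d,) + input_shape + (n_params,),
                       dtype=onp.float32)

# one iterator per device, each serving ``n_per_device`` simulations
# (paired with their derivatives for ``main``) at a time
main = [
    tf.data.Dataset.from_tensor_slices(
        (fiducial[:n_d][i::n_devices], derivative[i::n_devices]))
    .batch(n_per_device).repeat().as_numpy_iterator()
    for i in range(n_devices)]
remaining = [
    tf.data.Dataset.from_tensor_slices(fiducial[n_d:][i::n_devices])
    .batch(n_per_device).repeat().as_numpy_iterator()
    for i in range(n_devices)]
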
Example #2
    def _set_devices(self, devices, n_per_device):
        """Checks that devices exist and that reshaping onto devices can occur

        Because the computation is aggregated across devices, balanced splits
        of the simulations must be made between them, and so these are checked.

        Parameters
        ----------
        devices: list
            A list of the available jax devices (from ``jax.devices()``)
        n_per_device: int
            Number of simulations to handle at once; for the best performance
            this should be as large as possible without overflowing the memory

        Raises
        ------
        ValueError
            If ``devices`` or ``n_per_device`` are None
        ValueError
            If balanced splitting cannot be done
        TypeError
            If ``devices`` is not a list and if ``n_per_device`` is not an int
        """
        self.devices = _check_devices(devices)
        self.n_devices = len(self.devices)
        self.n_per_device = _check_type(n_per_device, int, "n_per_device")
        if self.n_s == self.n_d:
            _check_splitting(self.n_s, "n_s and n_d", self.n_devices,
                             self.n_per_device)
        else:
            _check_splitting(self.n_s, "n_s", self.n_devices,
                             self.n_per_device)
            _check_splitting(self.n_d, "n_d", self.n_devices,
                             self.n_per_device)
    def _set_dataset(self, prefetch=None, cache=None):
        """Overwritten function to prevent building dataset, does list check

        Raises
        ------
        ValueError
            If ``fiducial`` or ``derivative`` are None
        ValueError
            If any dataset has the wrong shape
        TypeError
            If any dataset has the wrong type
        """
        _check_type(self.fiducial, list, "fiducial", shape=self.n_devices)
        _check_type(self.derivative, list, "derivative", shape=self.n_devices)
        if self.validate:
            _check_type(self.validation_fiducial, list, "validation_fiducial",
                        shape=self.n_devices)
            _check_type(self.validation_derivative, list,
                        "validation_derivative", shape=self.n_devices)
Example #4
    def _setup_progress_bar(self, print_rate, max_iterations):
        """Construct progress bar

        Parameters
        ----------
        print_rate : int or None
            The rate at which the progress bar is updated (no bar if None)
        max_iterations : int
            The maximum number of iterations, used to set the bar's upper limit

        Returns
        -------
        progress bar or None:
            The TQDM progress bar object
        int or None:
            The print rate (after checking for int or None)
        int or None:
            The remainder of ``max_iterations`` divided by ``print_rate``

        Raises
        ------
        TypeError:
            If ``print_rate`` is not an integer
        """
        print_rate = _check_type(print_rate,
                                 int,
                                 "print_rate",
                                 allow_None=True)
        if print_rate is not None:
            if max_iterations < 10000:
                pbar = tqdm.tqdm(total=max_iterations)
            else:
                pbar = tqdm.tqdm()
            remainder = max_iterations % print_rate
            return pbar, print_rate, remainder
        else:
            return None, None, None
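
To see why the remainder is returned alongside the bar, here is a
minimal standalone sketch of the update cadence, assuming the bar is
advanced every ``print_rate`` iterations; ``do_iteration`` is a
placeholder, not library code:

import tqdm

def do_iteration():
    pass  # stand-in for one fitting step

max_iterations, print_rate = 1000, 64
pbar = tqdm.tqdm(total=max_iterations)
remainder = max_iterations % print_rate  # 1000 % 64 = 40
for i in range(max_iterations):
    do_iteration()
    if (i + 1) % print_rate == 0:
        pbar.update(print_rate)
pbar.update(remainder)  # flush the final partial chunk
pbar.close()
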
Example #5
    def fit(self,
            λ,
            ϵ,
            rng=None,
            patience=100,
            min_iterations=100,
            max_iterations=int(1e5),
            print_rate=None,
            best=True):
        """Fitting routine for the IMNN

        Parameters
        ----------
        λ : float
            Coupling strength of the regularisation
        ϵ : float
            Closeness criterion describing how close to 1 the determinant
            of the covariance (and inverse covariance) of the network outputs
            is desired to be
        rng : int(2,) or None, default=None
            Stateless random number generator
        patience : int, default=100
            Number of iterations where there is no increase in the value of the
            determinant of the Fisher information matrix, used for early
            stopping
        min_iterations : int, default=100
            Number of iterations that should be run before considering early
            stopping using the patience counter
        max_iterations : int, default=int(1e5)
            Maximum number of iterations to run the fitting procedure for
        print_rate : int or None, default=None
            Number of iterations before updating the progress bar whilst
            fitting. There is a performance hit from updating the progress bar
            more often and there is a large performance hit from using the
            progress bar at all. (Possible ``RET_CHECK`` failure if
            ``print_rate`` is not ``None`` when using GPUs).
            For this reason it is set to None by default.
        best : bool, default=True
            Whether to set the network parameter attribute ``self.w`` to the
            parameter values that obtained the maximum determinant of
            the Fisher information matrix or the parameter values at the final
            iteration of fitting

        Example
        -------

        We are going to summarise the mean and variance of some random Gaussian
        noise with 10 data points per example using an AggregatedSimulatorIMNN.
        In this case we are going to generate the simulations on-the-fly with a
        simulator written in jax (from the examples directory) and pass them
        through the network on each of the GPUs in ``jax.devices("gpu")``,
        making 100 simulations on each device at a time. The main computation
        will be done
        on the CPU. We will use 1000 simulations to estimate the covariance of
        the network outputs and the derivative of the mean of the network
        outputs with respect to the model parameters (Gaussian mean and
        variance) and generate the simulations at a fiducial μ=0 and Σ=1. The
        network will be a stax model with hidden layers of ``[128, 128, 128]``
        activated with leaky relu and outputting 2 summaries. Optimisation will
        be via Adam with a step size of ``1e-3``. Rather arbitrarily we'll set
        the regularisation strength and covariance identity constraint to λ=10
        and ϵ=0.1 (these are relatively unimportant for such an easy model).

        .. code-block:: python

            import jax
            import jax.numpy as np
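            # NB: in newer JAX releases ``stax`` and ``optimizers`` live in
            # ``jax.example_libraries`` rather than ``jax.experimental``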
            from jax.experimental import stax, optimizers
            from imnn import AggregatedSimulatorIMNN

            rng = jax.random.PRNGKey(0)

            n_s = 1000
            n_d = 1000
            n_params = 2
            n_summaries = 2
            input_shape = (10,)
            θ_fid = np.array([0., 1.])

            def simulator(rng, θ):
                return θ[0] + jax.random.normal(
                    rng, shape=input_shape) * np.sqrt(θ[1])

            model = stax.serial(
                stax.Dense(128),
                stax.LeakyRelu,
                stax.Dense(128),
                stax.LeakyRelu,
                stax.Dense(128),
                stax.LeakyRelu,
                stax.Dense(n_summaries))
            optimiser = optimizers.adam(step_size=1e-3)

            λ = 10.
            ϵ = 0.1

            model_key, fit_key = jax.random.split(rng)

            host = jax.devices("cpu")[0]
            devices = jax.devices("gpu")

            n_per_device = 100

            imnn = AggregatedSimulatorIMNN(
                n_s=n_s, n_d=n_d, n_params=n_params, n_summaries=n_summaries,
                input_shape=input_shape, θ_fid=θ_fid, model=model,
                optimiser=optimiser, key_or_state=model_key,
                simulator=simulator, host=host, devices=devices,
                n_per_device=n_per_device)

            imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000, patience=250,
                     print_rate=None)


        Notes
        -----
        A minimum number of iterations should be run before stopping based on
        the maximum determinant of the Fisher information achieved, since the
        loss function has dual objectives. Because the determinant of the
        covariance of the network outputs is quickly forced to 1, this can be
        to the detriment of the value of the determinant of the Fisher
        information matrix early in the fitting procedure. For this reason
        starting early stopping after the covariance has converged is advised.
        This is not currently implemented but could be considered in the
        future.

        The best fit network parameter values are probably not the most
        representative set of parameters when simulating on-the-fly since there
        is a high chance of a statistically overly-informative set of data
        being generated. Instead, if using
        :func:`~imnn.AggregatedSimulatorIMNN.fit()` consider using
        ``best=False`` which sets ``self.w=self.final_w`` which are the network
        parameter values obtained in the last iteration. Also consider using a
        larger ``patience`` value if using :func:`~imnn.SimulatorIMNN.fit()`
        to overcome the fact that a flukish high value for the determinant
        might have been obtained due to the realisation of the dataset.

        Raises
        ------
        TypeError
            If any input has the wrong type
        ValueError
            If any input (except ``rng``) are ``None``
        ValueError
            If ``rng`` has the wrong shape
        ValueError
            If ``rng`` is ``None`` but simulating on-the-fly

        Methods
        -------
        get_keys_and_params:
            Jitted collection of parameters and random numbers
        calculate_loss:
            Returns the jitted gradient of the loss function wrt summaries
        validation_loss:
            Jitted loss and auxiliary statistics from validation set

        Todo
        ----
        - ``rng`` is currently only used for on-the-fly simulation but could
          easily be updated to allow for stochastic models
        - Automatic detection, based on the value of ``r``, of when early
          stopping can be started
        """
        @jax.jit
        def get_keys_and_params(rng, state):
            """Jitted collection of parameters and random numbers

            Parameters
            ----------
            rng : int(2,) or None, default=None
                Stateless random number generator
            state : :obj:`state`
                The optimiser state used for updating the network parameters
                and optimisation algorithm

            Returns
            -------
            int(2,) or None:
                Stateless random number generator
            int(2,) or None:
                Stateless random number generator for training
            int(2,) or None:
                Stateless random number generator for validation
            list:
                Network parameter values
            """
            rng, training_key, validation_key = self._get_fitting_keys(rng)
            w = self._get_parameters(state)
            return rng, training_key, validation_key, w

        @jax.jit
        @partial(jax.grad, argnums=(0, 1), has_aux=True)
        def calculate_loss(summaries, summary_derivatives):
            """Returns the jitted gradient of the loss function wrt summaries

            Used to calculate the gradient of the loss function wrt summaries
            and derivatives of the summaries with respect to model parameters
            which will be used to calculate the aggregated gradient of the
            Fisher information with respect to the network parameters via the
            chain rule.

            Parameters
            ----------
            summaries : float(n_s, n_summaries)
                The network outputs
            summary_derivatives : float(n_d, n_summaries, n_params)
                The derivative of the network outputs wrt the model parameters

            Returns
            -------
            tuple:
                Gradient of the loss function with respect to network outputs
                and their derivatives with respect to physical model parameters
            tuple:
                Fitting statistics calculated on a single iteration
                    - **F** *(float(n_params, n_params))* -- Fisher information
                      matrix
                    - **C** *(float(n_summaries, n_summaries))* -- covariance
                      of network outputs
                    - **invC** *(float(n_summaries, n_summaries))* -- inverse
                      covariance of network outputs
                    - **Λ2** *(float)* -- covariance regularisation
                    - **r** *(float)* -- regularisation coupling strength
            """
            return self._calculate_loss(summaries, summary_derivatives, λ, α)

        @jax.jit
        def validation_loss(summaries, derivatives):
            """Jitted loss and auxillary statistics from validation set

            Parameters
            ----------
            summaries : float(n_s, n_summaries)
                The network outputs
            derivatives : float(n_d, n_summaries, n_params)
                The derivative of the network outputs wrt the model parameters

            Returns
            -------
            tuple:
                Fitting statistics calculated on a single validation iteration
                    - **F** *(float(n_params, n_params))* -- Fisher information
                      matrix
                    - **C** *(float(n_summaries, n_summaries))* -- covariance
                      of network outputs
                    - **invC** *(float(n_summaries, n_summaries))* -- inverse
                      covariance of network outputs
                    - **Λ2** *(float)* -- covariance regularisation
                    - **r** *(float)* -- regularisation coupling strength
            """
            F, C, invC, *_ = self._calculate_F_statistics(
                summaries, derivatives)
            _Λ2 = self._get_regularisation(C, invC)
            _r = self._get_regularisation_strength(_Λ2, λ, α)
            return (F, C, invC, _Λ2, _r)

        λ = _check_type(λ, float, "λ")
        ϵ = _check_type(ϵ, float, "ϵ")
        α = self.get_α(λ, ϵ)
        patience = _check_type(patience, int, "patience")
        min_iterations = _check_type(min_iterations, int, "min_iterations")
        max_iterations = _check_type(max_iterations, int, "max_iterations")
        best = _check_boolean(best, "best")
        if self.simulate and (rng is None):
            raise ValueError("`rng` is necessary when simulating.")
        rng = _check_input(rng, (2, ), "rng", allow_None=True)
        max_detF, best_w, detF, detC, detinvC, Λ2, r, counter, \
            patience_counter, state, rng = self._set_inputs(
                rng, max_iterations)
        pbar, print_rate, remainder = self._setup_progress_bar(
            print_rate, max_iterations)
        while self._fit_cond((max_detF, best_w, detF, detC, detinvC, Λ2, r,
                              counter, patience_counter, state, rng),
                             patience=patience,
                             max_iterations=max_iterations):
            rng, training_key, validation_key, w = get_keys_and_params(
                rng, state)
            summaries, summary_derivatives = self.get_summaries(
                w=w, key=training_key)
            dΛ_dx, results = calculate_loss(summaries, summary_derivatives)
            grad = self.get_gradient(dΛ_dx, w, key=training_key)
            state = self._update(counter, grad, state)
            w = self._get_parameters(state)
            detF, detC, detinvC, Λ2, r = self._update_history(
                results, (detF, detC, detinvC, Λ2, r), counter, 0)
            if self.validate:
                summaries, summary_derivatives = self.get_summaries(
                    w=w, key=training_key, validate=True)
                results = validation_loss(summaries, summary_derivatives)
                detF, detC, detinvC, Λ2, r = self._update_history(
                    results, (detF, detC, detinvC, Λ2, r), counter, 1)
            _detF = np.linalg.det(results[0])
            patience_counter, counter, _, max_detF, __, best_w = \
                jax.lax.cond(
                    np.greater(_detF, max_detF),
                    self._update_loop_vars,
                    lambda inputs: self._check_loop_vars(
                        inputs, min_iterations),
                    (patience_counter, counter, _detF, max_detF, w, best_w))
            self._update_progress_bar(pbar, counter, patience_counter,
                                      max_detF, detF[counter], detC[counter],
                                      detinvC[counter], Λ2[counter],
                                      r[counter], print_rate, max_iterations,
                                      remainder)
            counter += 1
        self._update_progress_bar(pbar,
                                  counter,
                                  patience_counter,
                                  max_detF,
                                  detF[counter - 1],
                                  detC[counter - 1],
                                  detinvC[counter - 1],
                                  Λ2[counter - 1],
                                  r[counter - 1],
                                  print_rate,
                                  max_iterations,
                                  remainder,
                                  close=True)
        self.history["max_detF"] = max_detF
        self.best_w = best_w
        self._set_history((detF[:counter], detC[:counter], detinvC[:counter],
                           Λ2[:counter], r[:counter]))
        self.state = state
        self.final_w = self._get_parameters(self.state)
        if best:
            w = self.best_w
        else:
            w = self.final_w
        self.set_F_statistics(w, key=rng)
    def _set_dataset(self, prefetch, cache):
        """ Transforms the data into lists of tensorflow dataset iterators

        Parameters
        ----------
        prefetch : tf.data.AUTOTUNE or int or None
            How many simulations to prefetch in the tensorflow dataset
        cache : bool
            Whether to cache simulations in the tensorflow datasets

        Raises
        ------
        ValueError
            If ``cache`` and/or ``prefetch`` is None
        TypeError
            If ``cache`` and/or ``prefetch`` is the wrong type
        """
        cache = _check_boolean(cache, "cache")
        prefetch = _check_type(prefetch, int, "prefetch", allow_None=True)
        self.fiducial = self.fiducial.reshape(self.fiducial_batch_shape +
                                              self.input_shape)
        self.fiducial = [
            tf.data.Dataset.from_tensor_slices(fiducial)
            for fiducial in self.fiducial
        ]

        self.derivative = self.derivative.reshape(self.derivative_batch_shape +
                                                  self.input_shape)
        self.derivative = [
            tf.data.Dataset.from_tensor_slices(derivative)
            for derivative in self.derivative
        ]

        if cache:
            self.fiducial = [fiducial.cache() for fiducial in self.fiducial]
            self.derivative = [
                derivative.cache() for derivative in self.derivative
            ]
        if prefetch is not None:
            self.fiducial = [
                fiducial.prefetch(prefetch) for fiducial in self.fiducial
            ]
            self.derivative = [
                derivative.prefetch(prefetch) for derivative in self.derivative
            ]

        self.fiducial = [
            fiducial.repeat().as_numpy_iterator() for fiducial in self.fiducial
        ]
        self.derivative = [
            derivative.repeat().as_numpy_iterator()
            for derivative in self.derivative
        ]

        if self.validate:
            self.validation_fiducial = self.validation_fiducial.reshape(
                self.fiducial_batch_shape + self.input_shape)
            self.validation_fiducial = [
                tf.data.Dataset.from_tensor_slices(fiducial)
                for fiducial in self.validation_fiducial
            ]

            self.validation_derivative = self.validation_derivative.reshape(
                self.derivative_batch_shape + self.input_shape)
            self.validation_derivative = [
                tf.data.Dataset.from_tensor_slices(derivative)
                for derivative in self.validation_derivative
            ]

            if cache:
                self.validation_fiducial = [
                    fiducial.cache() for fiducial in self.validation_fiducial
                ]
                self.validation_derivative = [
                    derivative.cache()
                    for derivative in self.validation_derivative
                ]
            if prefetch is not None:
                self.validation_fiducial = [
                    fiducial.prefetch(prefetch)
                    for fiducial in self.validation_fiducial
                ]
                self.validation_derivative = [
                    derivative.prefetch(prefetch)
                    for derivative in self.validation_derivative
                ]

            self.validation_fiducial = [
                fiducial.repeat().as_numpy_iterator()
                for fiducial in self.validation_fiducial
            ]
            self.validation_derivative = [
                derivative.repeat().as_numpy_iterator()
                for derivative in self.validation_derivative
            ]
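
A small numeric sketch of the reshape this method performs, assuming
``fiducial_batch_shape`` is ``(n_devices, n_batches, n_per_device)`` as
the splitting checks suggest; the exact attribute values come from the
class setup:

import numpy as onp

n_s, n_devices, n_per_device = 1000, 2, 100
input_shape = (10,)
fiducial = onp.zeros((n_s,) + input_shape)

# assumed: one axis per device, then batches of n_per_device simulations
fiducial_batch_shape = (n_devices,
                        n_s // (n_devices * n_per_device),
                        n_per_device)
reshaped = fiducial.reshape(fiducial_batch_shape + input_shape)
print(reshaped.shape)  # (2, 5, 100, 10): device, batch, simulation, data
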
Example #7
    def _set_dataset(self, prefetch, cache):
        """ Collates data into loopable tensorflow dataset iterations

        Parameters
        ----------
        prefetch : tf.data.AUTOTUNE or int or None
            How many simulations to prefetch in the tensorflow dataset
        cache : bool
            Whether to cache simulations in the tensorflow datasets

        Raises
        ------
        ValueError
            If ``cache`` is None
        TypeError
            If ``cache`` or ``prefetch`` is the wrong type
        """
        cache = _check_boolean(cache, "cache")
        prefetch = _check_type(prefetch, int, "prefetch", allow_None=True)
        self.remaining = [
            tf.data.Dataset.from_tensor_slices(fiducial)
            for fiducial in self.fiducial[self.n_d:].reshape(
                self.remaining_batch_shape + self.input_shape)
        ]

        self.main = [
            tf.data.Dataset.from_tensor_slices((fiducial, derivative))
            for fiducial, derivative in zip(
                self.fiducial[:self.n_d].reshape(self.batch_shape +
                                                 self.input_shape),
                self.derivative.reshape(self.batch_shape + self.input_shape +
                                        (self.n_params, )))
        ]

        if cache:
            self.remaining = [
                remaining.cache() for remaining in self.remaining
            ]
            self.main = [main.cache() for main in self.main]
        if prefetch is not None:
            self.remaining = [
                remaining.prefetch(prefetch) for remaining in self.remaining
            ]
            self.main = [main.prefetch(prefetch) for main in self.main]

        self.main = [main.repeat().as_numpy_iterator() for main in self.main]
        self.remaining = [
            remaining.repeat().as_numpy_iterator()
            for remaining in self.remaining
        ]

        if self.validate:
            self.validation_remaining = [
                tf.data.Dataset.from_tensor_slices(fiducial)
                for fiducial in self.validation_fiducial[self.n_d:].reshape(
                    self.remaining_batch_shape + self.input_shape)
            ]

            self.validation_main = [
                tf.data.Dataset.from_tensor_slices((fiducial, derivative))
                for fiducial, derivative in zip(
                    self.validation_fiducial[:self.n_d].reshape(
                        self.batch_shape + self.input_shape),
                    self.validation_derivative.reshape(self.batch_shape +
                                                       self.input_shape +
                                                       (self.n_params, )))
            ]

            if cache:
                self.validation_remaining = [
                    remaining.cache()
                    for remaining in self.validation_remaining
                ]
                self.validation_main = [
                    main.cache() for main in self.validation_main
                ]
            if prefetch is not None:
                self.validation_remaining = [
                    remaining.prefetch(prefetch)
                    for remaining in self.validation_remaining
                ]
                self.validation_main = [
                    main.prefetch(prefetch) for main in self.validation_main
                ]

            self.validation_main = [
                main.repeat().as_numpy_iterator()
                for main in self.validation_main
            ]
            self.validation_remaining = [
                remaining.repeat().as_numpy_iterator()
                for remaining in self.validation_remaining
            ]
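
Once built, the iterators are consumed with ``next``, one per device; a
hedged illustration, assuming ``imnn`` is an already-initialised
instance of the class this method belongs to, with the batch shapes
sketched above:

for main_iterator in imnn.main:
    fiducial_batch, derivative_batch = next(main_iterator)
    # fiducial_batch:   (n_per_device,) + input_shape
    # derivative_batch: (n_per_device,) + input_shape + (n_params,)
for remaining_iterator in imnn.remaining:
    extra_batch = next(remaining_iterator)  # (n_per_device,) + input_shape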