def _set_prebuilt_dataset(
        self, main, remaining, validation_main, validation_remaining):
    """Set preconstructed dataset iterators

    Parameters
    ----------
    main : list of tf.data.Dataset().as_numpy_iterators()
        The simulations generated at the fiducial model parameter values
        used for calculating the covariance of network outputs and their
        derivatives with respect to the physical model parameters (for
        fitting). These are served ``n_per_device`` at a time as a numpy
        iterator from a TensorFlow dataset.
    remaining : list of tf.data.Dataset().as_numpy_iterators()
        The ``n_s - n_d`` simulations generated at the fiducial model
        parameter values used for calculating the covariance of network
        outputs without a derivative counterpart (for fitting). These are
        served ``n_per_device`` at a time as a numpy iterator from a
        TensorFlow dataset.
    validation_main : list of tf.data.Dataset().as_numpy_iterators()
        The simulations generated at the fiducial model parameter values
        used for calculating the covariance of network outputs and their
        derivatives with respect to the physical model parameters (for
        validation). These are served ``n_per_device`` at a time as a
        numpy iterator from a TensorFlow dataset.
    validation_remaining : list of tf.data.Dataset().as_numpy_iterators()
        The ``n_s - n_d`` simulations generated at the fiducial model
        parameter values used for calculating the covariance of network
        outputs without a derivative counterpart (for validation). These
        are served ``n_per_device`` at a time as a numpy iterator from a
        TensorFlow dataset.

    Raises
    ------
    ValueError
        If ``main`` or ``remaining`` are None
    ValueError
        If the length of any input list is not equal to the number of
        devices
    TypeError
        If any input is not a list
    """
    self.main = _check_type(main, list, "main", shape=self.n_devices)
    self.remaining = _check_type(
        remaining, list, "remaining", shape=self.n_devices)
    if (validation_main is not None) and (validation_remaining is not None):
        self.validation_main = _check_type(
            validation_main, list, "validation_main",
            shape=self.n_devices)
        self.validation_remaining = _check_type(
            validation_remaining, list, "validation_remaining",
            shape=self.n_devices)
        self.validate = True
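# A minimal sketch of how the iterator lists expected by
# `_set_prebuilt_dataset` might be built by hand, mirroring the pairing of
# fiducial simulations with derivatives done in `_set_dataset` below. The
# array names (`fiducial`, `derivative`) and the sizes are hypothetical:
# we assume `n_devices` equal splits, each served `n_per_device` at a time.
import numpy as onp
import tensorflow as tf

n_devices, n_per_device, n_s, n_d = 2, 100, 1000, 200
input_shape = (10,)
n_params = 2

fiducial = onp.zeros((n_s,) + input_shape)                   # placeholder data
derivative = onp.zeros((n_d,) + input_shape + (n_params,))   # placeholder data

# pair the first n_d fiducial simulations with their derivatives per device
main = [
    tf.data.Dataset.from_tensor_slices(
        (fiducial[:n_d].reshape(
            (n_devices, -1, n_per_device) + input_shape)[i],
         derivative.reshape(
            (n_devices, -1, n_per_device) + input_shape + (n_params,))[i])
    ).repeat().as_numpy_iterator()
    for i in range(n_devices)]

# the other n_s - n_d simulations have no derivative counterpart
remaining = [
    tf.data.Dataset.from_tensor_slices(
        fiducial[n_d:].reshape(
            (n_devices, -1, n_per_device) + input_shape)[i]
    ).repeat().as_numpy_iterator()
    for i in range(n_devices)]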
def _set_devices(self, devices, n_per_device):
    """Checks that devices exist and that reshaping onto devices can occur

    Due to the aggregation, balanced splits must be made across the
    different devices, and so these are checked.

    Parameters
    ----------
    devices : list
        A list of the available jax devices (from ``jax.devices()``)
    n_per_device : int
        Number of simulations to handle at once; this should be as large
        as possible without letting the memory overflow for the best
        performance

    Raises
    ------
    ValueError
        If ``devices`` or ``n_per_device`` are None
    ValueError
        If balanced splitting cannot be done
    TypeError
        If ``devices`` is not a list or if ``n_per_device`` is not an int
    """
    self.devices = _check_devices(devices)
    self.n_devices = len(self.devices)
    self.n_per_device = _check_type(n_per_device, int, "n_per_device")
    if self.n_s == self.n_d:
        _check_splitting(
            self.n_s, "n_s and n_d", self.n_devices, self.n_per_device)
    else:
        _check_splitting(self.n_s, "n_s", self.n_devices, self.n_per_device)
        _check_splitting(self.n_d, "n_d", self.n_devices, self.n_per_device)
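# A minimal sketch of the balanced-split condition enforced by
# `_check_splitting` above (the helper below is hypothetical, not the
# library's implementation): the number of simulations must divide evenly
# into `n_devices` stacks of `n_per_device`-sized batches.
def _can_split(n, n_devices, n_per_device):
    return n % (n_devices * n_per_device) == 0

assert _can_split(1000, n_devices=2, n_per_device=100)      # 1000 = 2 * 5 * 100
assert not _can_split(1000, n_devices=3, n_per_device=100)  # no balanced split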
def _set_dataset(self, prefetch=None, cache=None):
    """Overwritten function to prevent building datasets; does list check

    Raises
    ------
    ValueError
        If ``fiducial`` or ``derivative`` are None
    ValueError
        If any dataset has the wrong shape
    TypeError
        If any dataset has the wrong type
    """
    _check_type(self.fiducial, list, "fiducial", shape=self.n_devices)
    _check_type(self.derivative, list, "derivative", shape=self.n_devices)
    if self.validate:
        _check_type(
            self.validation_fiducial, list, "validation_fiducial",
            shape=self.n_devices)
        _check_type(
            self.validation_derivative, list, "validation_derivative",
            shape=self.n_devices)
def _setup_progress_bar(self, print_rate, max_iterations):
    """Construct progress bar

    Parameters
    ----------
    print_rate : int or None
        The rate at which the progress bar is updated (no bar if None)
    max_iterations : int
        The maximum number of iterations, used to set the bar's upper limit

    Returns
    -------
    progress bar or None:
        The TQDM progress bar object
    int or None:
        The print rate (after checking for int or None)
    int or None:
        The remainder of ``max_iterations`` divided by the print rate

    Raises
    ------
    TypeError
        If ``print_rate`` is not an integer
    """
    print_rate = _check_type(print_rate, int, "print_rate", allow_None=True)
    if print_rate is not None:
        if max_iterations < 10000:
            pbar = tqdm.tqdm(total=max_iterations)
        else:
            pbar = tqdm.tqdm()
        remainder = max_iterations % print_rate
        return pbar, print_rate, remainder
    else:
        return None, None, None
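# A hedged sketch of how the three values returned above might drive bar
# updates during fitting (a hypothetical loop; the real update logic lives
# in `_update_progress_bar`): the bar advances every `print_rate`
# iterations, and `remainder` accounts for the final partial update before
# the bar is closed.
import tqdm

max_iterations, print_rate = 1000, 10
pbar = tqdm.tqdm(total=max_iterations)
remainder = max_iterations % print_rate
for counter in range(max_iterations):
    if (counter + 1) % print_rate == 0:
        pbar.update(print_rate)
pbar.update(remainder)  # zero here, non-zero when print_rate doesn't divide evenly
pbar.close()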
def fit(self, λ, ϵ, rng=None, patience=100, min_iterations=100,
        max_iterations=int(1e5), print_rate=None, best=True):
    """Fitting routine for the IMNN

    Parameters
    ----------
    λ : float
        Coupling strength of the regularisation
    ϵ : float
        Closeness criterion describing how close to 1 the determinant of
        the covariance (and inverse covariance) of the network outputs is
        desired to be
    rng : int(2,) or None, default=None
        Stateless random number generator
    patience : int, default=100
        Number of iterations where there is no increase in the value of
        the determinant of the Fisher information matrix, used for early
        stopping
    min_iterations : int, default=100
        Number of iterations that should be run before considering early
        stopping using the patience counter
    max_iterations : int, default=int(1e5)
        Maximum number of iterations to run the fitting procedure for
    print_rate : int or None, default=None
        Number of iterations before updating the progress bar whilst
        fitting. There is a performance hit from updating the progress bar
        more often, and a large performance hit from using the progress
        bar at all (possible ``RET_CHECK`` failure if ``print_rate`` is
        not ``None`` when using GPUs). For this reason it is set to None
        by default
    best : bool, default=True
        Whether to set the network parameter attribute ``self.w`` to the
        parameter values that obtained the maximum determinant of the
        Fisher information matrix or the parameter values at the final
        iteration of fitting

    Example
    -------
    We are going to summarise the mean and variance of some random
    Gaussian noise with 10 data points per example using an
    AggregatedSimulatorIMNN. In this case we are going to generate the
    simulations on-the-fly with a simulator written in jax (from the
    examples directory). These simulations will be passed through the
    network on each of the GPUs in ``jax.devices("gpu")``, and we will
    make 100 simulations on each device at a time. The main computation
    will be done on the CPU. We will use 1000 simulations to estimate the
    covariance of the network outputs and the derivative of the mean of
    the network outputs with respect to the model parameters (Gaussian
    mean and variance) and generate the simulations at a fiducial μ=0 and
    Σ=1. The network will be a stax model with hidden layers of
    ``[128, 128, 128]`` activated with leaky relu and outputting 2
    summaries. Optimisation will be via Adam with a step size of
    ``1e-3``. Rather arbitrarily we'll set the regularisation strength
    and covariance identity constraint to λ=10 and ϵ=0.1 (these are
    relatively unimportant for such an easy model).

    .. code-block:: python

        import jax
        import jax.numpy as np
        from jax.experimental import stax, optimizers
        from imnn import AggregatedSimulatorIMNN

        rng = jax.random.PRNGKey(0)

        n_s = 1000
        n_d = 1000

        n_params = 2
        n_summaries = 2

        input_shape = (10,)
        θ_fid = np.array([0., 1.])

        def simulator(rng, θ):
            return θ[0] + jax.random.normal(
                rng, shape=input_shape) * np.sqrt(θ[1])

        model = stax.serial(
            stax.Dense(128),
            stax.LeakyRelu,
            stax.Dense(128),
            stax.LeakyRelu,
            stax.Dense(128),
            stax.LeakyRelu,
            stax.Dense(n_summaries))

        optimiser = optimizers.adam(step_size=1e-3)

        λ = 10.
        ϵ = 0.1

        model_key, fit_key = jax.random.split(rng)

        host = jax.devices("cpu")[0]
        devices = jax.devices("gpu")

        n_per_device = 100

        imnn = AggregatedSimulatorIMNN(
            n_s=n_s, n_d=n_d, n_params=n_params,
            n_summaries=n_summaries, input_shape=input_shape,
            θ_fid=θ_fid, model=model, optimiser=optimiser,
            key_or_state=model_key, simulator=simulator, host=host,
            devices=devices, n_per_device=n_per_device)

        imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000, patience=250,
                 print_rate=None)

    Notes
    -----
    A minimum number of iterations should be run before stopping based on
    a maximum determinant of the Fisher information achieved, since the
    loss function has dual objectives. Since the determinant of the
    covariance of the network outputs is forced to 1 quickly, this can be
    to the detriment of the value of the determinant of the Fisher
    information matrix early in the fitting procedure. For this reason,
    starting early stopping after the covariance has converged is
    advised. This is not currently implemented but could be considered in
    the future.

    The best-fit network parameter values are probably not the most
    representative set of parameters when simulating on-the-fly since
    there is a high chance of a statistically overly-informative set of
    data being generated. Instead, if using
    :func:`~imnn.AggregatedSimulatorIMNN.fit()` consider using
    ``best=False``, which sets ``self.w=self.final_w``, i.e. the network
    parameter values obtained in the last iteration. Also consider using
    a larger ``patience`` value if using
    :func:`~imnn.SimulatorIMNN.fit()` to overcome the fact that a flukish
    high value for the determinant might have been obtained due to the
    realisation of the dataset.

    Raises
    ------
    TypeError
        If any input has the wrong type
    ValueError
        If any input (except ``rng``) is ``None``
    ValueError
        If ``rng`` has the wrong shape
    ValueError
        If ``rng`` is ``None`` but simulating on-the-fly

    Methods
    -------
    get_keys_and_params:
        Jitted collection of parameters and random numbers
    calculate_loss:
        Returns the jitted gradient of the loss function wrt summaries
    validation_loss:
        Jitted loss and auxiliary statistics from validation set

    Todo
    ----
    - ``rng`` is currently only used for on-the-fly simulation but could
      easily be updated to allow for stochastic models
    - Automatic detection of convergence based on the value of ``r`` when
      early stopping can be started
    """
    @jax.jit
    def get_keys_and_params(rng, state):
        """Jitted collection of parameters and random numbers

        Parameters
        ----------
        rng : int(2,) or None, default=None
            Stateless random number generator
        state : :obj:state
            The optimiser state used for updating the network parameters
            and optimisation algorithm

        Returns
        -------
        int(2,) or None:
            Stateless random number generator
        int(2,) or None:
            Stateless random number generator for training
        int(2,) or None:
            Stateless random number generator for validation
        list:
            Network parameter values
        """
        rng, training_key, validation_key = self._get_fitting_keys(rng)
        w = self._get_parameters(state)
        return rng, training_key, validation_key, w

    @jax.jit
    @partial(jax.grad, argnums=(0, 1), has_aux=True)
    def calculate_loss(summaries, summary_derivatives):
        """Returns the jitted gradient of the loss function wrt summaries

        Used to calculate the gradient of the loss function wrt summaries
        and derivatives of the summaries with respect to model parameters,
        which will be used to calculate the aggregated gradient of the
        Fisher information with respect to the network parameters via the
        chain rule.
        Parameters
        ----------
        summaries : float(n_s, n_summaries)
            The network outputs
        summary_derivatives : float(n_d, n_summaries, n_params)
            The derivative of the network outputs wrt the model parameters

        Returns
        -------
        tuple:
            Gradient of the loss function with respect to network outputs
            and their derivatives with respect to physical model
            parameters
        tuple:
            Fitting statistics calculated on a single iteration

            - **F** *(float(n_params, n_params))* -- Fisher information
              matrix
            - **C** *(float(n_summaries, n_summaries))* -- covariance of
              network outputs
            - **invC** *(float(n_summaries, n_summaries))* -- inverse
              covariance of network outputs
            - **Λ2** *(float)* -- covariance regularisation
            - **r** *(float)* -- regularisation coupling strength
        """
        return self._calculate_loss(summaries, summary_derivatives, λ, α)

    @jax.jit
    def validation_loss(summaries, derivatives):
        """Jitted loss and auxiliary statistics from validation set

        Parameters
        ----------
        summaries : float(n_s, n_summaries)
            The network outputs
        derivatives : float(n_d, n_summaries, n_params)
            The derivative of the network outputs wrt the model parameters

        Returns
        -------
        tuple:
            Fitting statistics calculated on a single validation iteration

            - **F** *(float(n_params, n_params))* -- Fisher information
              matrix
            - **C** *(float(n_summaries, n_summaries))* -- covariance of
              network outputs
            - **invC** *(float(n_summaries, n_summaries))* -- inverse
              covariance of network outputs
            - **Λ2** *(float)* -- covariance regularisation
            - **r** *(float)* -- regularisation coupling strength
        """
        F, C, invC, *_ = self._calculate_F_statistics(
            summaries, derivatives)
        _Λ2 = self._get_regularisation(C, invC)
        _r = self._get_regularisation_strength(_Λ2, λ, α)
        return (F, C, invC, _Λ2, _r)

    λ = _check_type(λ, float, "λ")
    ϵ = _check_type(ϵ, float, "ϵ")
    α = self.get_α(λ, ϵ)
    patience = _check_type(patience, int, "patience")
    min_iterations = _check_type(min_iterations, int, "min_iterations")
    max_iterations = _check_type(max_iterations, int, "max_iterations")
    best = _check_boolean(best, "best")
    if self.simulate and (rng is None):
        raise ValueError("`rng` is necessary when simulating.")
    rng = _check_input(rng, (2,), "rng", allow_None=True)
    max_detF, best_w, detF, detC, detinvC, Λ2, r, counter, \
        patience_counter, state, rng = self._set_inputs(
            rng, max_iterations)
    pbar, print_rate, remainder = self._setup_progress_bar(
        print_rate, max_iterations)
    while self._fit_cond(
            (max_detF, best_w, detF, detC, detinvC, Λ2, r, counter,
             patience_counter, state, rng),
            patience=patience, max_iterations=max_iterations):
        rng, training_key, validation_key, w = get_keys_and_params(
            rng, state)
        summaries, summary_derivatives = self.get_summaries(
            w=w, key=training_key)
        dΛ_dx, results = calculate_loss(summaries, summary_derivatives)
        grad = self.get_gradient(dΛ_dx, w, key=training_key)
        state = self._update(counter, grad, state)
        w = self._get_parameters(state)
        detF, detC, detinvC, Λ2, r = self._update_history(
            results, (detF, detC, detinvC, Λ2, r), counter, 0)
        if self.validate:
            summaries, summary_derivatives = self.get_summaries(
                w=w, key=validation_key, validate=True)
            results = validation_loss(summaries, summary_derivatives)
            detF, detC, detinvC, Λ2, r = self._update_history(
                results, (detF, detC, detinvC, Λ2, r), counter, 1)
        _detF = np.linalg.det(results[0])
        patience_counter, counter, _, max_detF, __, best_w = \
            jax.lax.cond(
                np.greater(_detF, max_detF),
                self._update_loop_vars,
                lambda inputs: self._check_loop_vars(
                    inputs, min_iterations),
                (patience_counter, counter, _detF, max_detF, w, best_w))
        self._update_progress_bar(
            pbar, counter, patience_counter, max_detF, detF[counter],
            detC[counter], detinvC[counter], Λ2[counter], r[counter],
            print_rate, max_iterations, remainder)
        counter += 1
    self._update_progress_bar(
        pbar, counter, patience_counter, max_detF, detF[counter - 1],
        detC[counter - 1], detinvC[counter - 1], Λ2[counter - 1],
        r[counter - 1], print_rate, max_iterations, remainder,
        close=True)
    self.history["max_detF"] = max_detF
    self.best_w = best_w
    self._set_history(
        (detF[:counter], detC[:counter], detinvC[:counter],
         Λ2[:counter], r[:counter]))
    self.state = state
    self.final_w = self._get_parameters(self.state)
    if best:
        w = self.best_w
    else:
        w = self.final_w
    self.set_F_statistics(w, key=rng)
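# A minimal sketch of the early-stopping condition that `_fit_cond` is
# understood to implement (a hypothetical standalone version, not the
# library's code): fitting continues while the patience counter, which only
# starts accumulating after `min_iterations` (see `_check_loop_vars` above),
# has not reached `patience` and the iteration counter is below
# `max_iterations`.
def fit_cond(patience_counter, counter, patience, max_iterations):
    return (patience_counter < patience) and (counter < max_iterations)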
def _set_dataset(self, prefetch, cache): """ Transforms the data into lists of tensorflow dataset iterators Parameters ---------- prefetch : tf.data.AUTOTUNE or int or None How many simulation to prefetch in the tensorflow dataset cache : bool Whether to cache simulations in the tensorflow datasets Raises ------ ValueError If ``cache`` and/or ``prefetch`` is None TypeError If ``cache`` and/or ``prefetch`` is wrong type """ cache = _check_boolean(cache, "cache") prefetch = _check_type(prefetch, int, "prefetch", allow_None=True) self.fiducial = self.fiducial.reshape(self.fiducial_batch_shape + self.input_shape) self.fiducial = [ tf.data.Dataset.from_tensor_slices(fiducial) for fiducial in self.fiducial ] self.derivative = self.derivative.reshape(self.derivative_batch_shape + self.input_shape) self.derivative = [ tf.data.Dataset.from_tensor_slices(derivative) for derivative in self.derivative ] if cache: self.fiducial = [fiducial.cache() for fiducial in self.fiducial] self.derivative = [ derivative.cache() for derivative in self.derivative ] if prefetch is not None: self.fiducial = [ fiducial.prefetch(prefetch) for fiducial in self.fiducial ] self.derivative = [ derivative.prefetch(prefetch) for derivative in self.derivative ] self.fiducial = [ fiducial.repeat().as_numpy_iterator() for fiducial in self.fiducial ] self.derivative = [ derivative.repeat().as_numpy_iterator() for derivative in self.derivative ] if self.validate: self.validation_fiducial = self.validation_fiducial.reshape( self.fiducial_batch_shape + self.input_shape) self.validation_fiducial = [ tf.data.Dataset.from_tensor_slices(fiducial) for fiducial in self.validation_fiducial ] self.validation_derivative = self.validation_derivative.reshape( self.derivative_batch_shape + self.input_shape) self.validation_derivative = [ tf.data.Dataset.from_tensor_slices(derivative) for derivative in self.validation_derivative ] if cache: self.validation_fiducial = [ fiducial.cache() for fiducial in self.validation_fiducial ] self.validation_derivative = [ derivative.cache() for derivative in self.validation_derivative ] if prefetch is not None: self.validation_fiducial = [ fiducial.prefetch(prefetch) for fiducial in self.validation_fiducial ] self.validation_derivative = [ derivative.prefetch(prefetch) for derivative in self.validation_derivative ] self.validation_fiducial = [ fiducial.repeat().as_numpy_iterator() for fiducial in self.validation_fiducial ] self.validation_derivative = [ derivative.repeat().as_numpy_iterator() for derivative in self.validation_derivative ]
def _set_dataset(self, prefetch, cache): """ Collates data into loopable tensorflow dataset iterations Parameters ---------- prefetch : tf.data.AUTOTUNE or int or None How many simulation to prefetch in the tensorflow dataset cache : bool Whether to cache simulations in the tensorflow datasets Raises ------ ValueError If cache is None TypeError If cache is wrong type """ cache = _check_boolean(cache, "cache") prefetch = _check_type(prefetch, int, "prefetch", allow_None=True) self.remaining = [ tf.data.Dataset.from_tensor_slices(fiducial) for fiducial in self.fiducial[self.n_d:].reshape( self.remaining_batch_shape + self.input_shape) ] self.main = [ tf.data.Dataset.from_tensor_slices((fiducial, derivative)) for fiducial, derivative in zip( self.fiducial[:self.n_d].reshape(self.batch_shape + self.input_shape), self.derivative.reshape(self.batch_shape + self.input_shape + (self.n_params, ))) ] if cache: self.remaining = [ remaining.cache() for remaining in self.remaining ] self.main = [main.cache() for main in self.main] if prefetch is not None: self.remaining = [ remaining.prefetch(prefetch) for remaining in self.remaining ] self.main = [main.prefetch(prefetch) for main in self.main] self.main = [main.repeat().as_numpy_iterator() for main in self.main] self.remaining = [ remaining.repeat().as_numpy_iterator() for remaining in self.remaining ] if self.validate: self.validation_remaining = [ tf.data.Dataset.from_tensor_slices(fiducial) for fiducial in self.validation_fiducial[self.n_d:].reshape( self.remaining_batch_shape + self.input_shape) ] self.validation_main = [ tf.data.Dataset.from_tensor_slices((fiducial, derivative)) for fiducial, derivative in zip( self.validation_fiducial[:self.n_d].reshape( self.batch_shape + self.input_shape), self.validation_derivative.reshape(self.batch_shape + self.input_shape + (self.n_params, ))) ] if cache: self.validation_remaining = [ remaining.cache() for remaining in self.validation_remaining ] self.validation_main = [ main.cache() for main in self.validation_main ] if prefetch is not None: self.validation_remaining = [ remaining.prefetch(prefetch) for remaining in self.validation_remaining ] self.validation_main = [ main.prefetch(prefetch) for main in self.validation_main ] self.validation_main = [ main.repeat().as_numpy_iterator() for main in self.validation_main ] self.validation_remaining = [ remaining.repeat().as_numpy_iterator() for remaining in self.validation_remaining ]