def test_clone_subclassed(self):
  """Cloning a subclassed two-output model must reproduce its predictions."""

  class TestModel(tf.keras.Model):
    """Small subclassed model with one hidden layer and two output heads."""

    def __init__(self):
      super().__init__()
      self.hidden = tf.keras.layers.Dense(10, activation='relu')
      self.out1 = tf.keras.layers.Dense(10, name='a')
      self.out2 = tf.keras.layers.Dense(10, name='nolabel')

    def call(self, inputs):
      features = self.hidden(inputs)
      return [self.out1(features), self.out2(features)]

    def get_config(self):
      # No constructor arguments, so the config is empty.
      return {}

    @staticmethod
    def from_config(config):
      del config  # Unused: the model is rebuilt from scratch.
      return TestModel()

  in_shape = (None, 10)
  source = TestModel()
  source.build(in_shape)
  inputs = tf.random.normal((20, 10))
  clone = bnnmodel.clone_model_and_weights(source, in_shape)
  # Arguments are evaluated left to right: source first, then the clone.
  self.assertAllClose(source(inputs), clone(inputs),
                      msg='model2 output differs from model1')
def __init__(self, model, input_shape=None, weights_list=None,
             clone_model=True):
  """Creates an empirical ensemble around a single model.

  Args:
    model: a tf.keras.Model (including tf.keras.models.Sequential), or a
      zero-argument factory that returns an object behaving like a
      tf.keras.Model, e.g. `lambda: YourModelClass(param=value, ...)`.
    input_shape: shape used when cloning a keras `model` and passed to
      `Model.build` for a factory-created model. Required (must not be None)
      whenever `clone_model` is True. TODO(nowozin): we hope this requirement
      will change in the future, see b/132994200.
    weights_list: list of weights compatible with `model`. Stored as given
      (not copied); defaults to a new empty list.
    clone_model: bool, default True. When True, a keras `model` is cloned
      together with its weights, and a factory is called and then built
      with `input_shape`. When False, `model` is stored unchanged, with no
      cloning, calling or building. If using keras model.fit(), set to
      False.

  Raises:
    ValueError: if `clone_model` is True and `input_shape` is None.
  """
  if not clone_model:
    self.model = model
  elif input_shape is None:
    raise ValueError("input_shape cannot be None in EmpiricalEnsemble "
                     "constructor")
  elif isinstance(model, tf.keras.Model):
    self.model = bnnmodel.clone_model_and_weights(model, input_shape)
  else:
    # `model` is a factory: instantiate it and build with the given shape.
    self.model = model()
    self.model.build(input_shape)

  self.input_shape = input_shape
  self.weights_list = [] if weights_list is None else weights_list
def test_clone(self):
  """A cloned Sequential model with regularizers must match the original."""
  regularizer = prior.SpikeAndSlabRegularizer(weight=1.0)
  original = tf.keras.models.Sequential([
      tf.keras.layers.Dense(
          10,
          activation='relu',
          kernel_regularizer=regularizer,
          bias_regularizer=regularizer),
      tf.keras.layers.Dense(10),
  ])
  inputs = tf.random.normal((256, 10))
  # Calling the model builds it before it is cloned.
  expected = original(inputs)
  clone = bnnmodel.clone_model_and_weights(original, (None, 10))
  # Drop the only local reference to the source model; presumably this
  # checks that the clone does not depend on it.
  del original
  actual = clone(inputs)
  self.assertAllClose(expected, actual, msg='Model cloning failed.')
def fit(self,
        dataset,
        y=None,
        statistics=None,
        epochs=1,
        validation_split=0.,
        validation_data=None,
        validation_freq=1,
        initial_epoch=0,
        **fit_kwargs):
  """Trains ensemble members using a different initialization for each model.

  Note: alpha functionality, experimental and prone to change.

  Missing members (from len(self.weights_list) up to self.n_members) are
  created by cloning self.model. Then, for every epoch in
  [initial_epoch, epochs), each member is trained for one epoch with
  `model.fit()`. Its weights are stored back into self.weights_list and its
  optimizer weights into self._optimizers_list.

  Arguments:
    dataset: Either:
      - tf.data.Dataset yielding (inputs, outputs, [sample_weight]) tuples
        suitable for self.model.fit() API. Input/outputs can be singular,
        tuples (ordered according to self.model.inputs/outputs) or dict (with
        keys matching from self.model.input/output_names).
      - Tuple (inputs, outputs) suitable for self.model.fit().
      - Inputs suitable for self.model.predict(inputs).
    y: outputs in case dataset is input tensor. When given, it takes
      precedence over unpacking `dataset` as an (inputs, outputs) tuple.
    statistics: mean statistics to evaluate on output, see
      self.model.compile(metrics=...) API:
      - For single-output model (e.g. keras.Sequential()): [Statistic, ...].
      - Multiple-output model (e.g. keras.Model(outputs=(...,))):
        tuple([Statistic, ...], ...) in order of outputs.
      - Multiple named output model (e.g. keras.Model(outputs={...})):
        dict(output_name=[Statistic, ...], ...)
      Every entry must be a stats.MeanStatistic.
    epochs: Index of the epoch at which to stop training (exclusive), as in
      keras; epochs initial_epoch, ..., epochs - 1 are run.
    validation_split: Passed to self._validate_fit_params for checking only.
      It is not forwarded to model.fit(), so no data is held out here.
    validation_data: Data handed to self.evaluate_ensemble together with
      `statistics`; it is not passed to the per-member model.fit() calls.
      Presumably any format accepted by evaluate_ensemble. Evaluation only
      happens when both `validation_data` and `statistics` are given.
    validation_freq: Integer. Ensemble statistics are evaluated after every
      epoch with `epoch % validation_freq == 0`. Epochs are 0-indexed, so
      with initial_epoch=0 the first epoch is always evaluated.
      NOTE(review): a Container (list/tuple) value is not handled by this
      check; confirm whether self._validate_fit_params rejects it.
    initial_epoch: Integer. Epoch at which to start training (useful for
      resuming a previous training run). Must be 0 if new members still
      need to be initialized.
    **fit_kwargs: Additional arguments forwarded to every member's
      model.fit() call. A "verbose" entry is removed and used for both
      model.fit() and progress logging.

  Returns:
    history: A list with one entry per evaluation epoch. Each entry is the
      result of self.evaluate_ensemble(validation_data, statistics). The list
      is empty when `validation_data` or `statistics` is None.

  Raises:
    ValueError: If an entry of `statistics` is not a MeanStatistic. Errors
      from model.fit() are propagated (e.g. a RuntimeError if the model was
      never compiled, or a ValueError on input/model mismatch).
    AssertionError: If new members must be initialized and initial_epoch is
      not 0.
  """
  # Check unsupported tf.keras.model.fit() arguments.
  fit_verbose = fit_kwargs.pop("verbose", None) or 0
  self._validate_fit_params(validation_split, validation_freq)

  # Collect output statistics.
  output_names = getattr(self.model, "output_names", None)
  ordered_stats_lists = _statistics_to_ordered_tuple(statistics, output_names)
  for stats_list in ordered_stats_lists:
    for stat in stats_list:
      if not isinstance(stat, stats.MeanStatistic):
        # TODO(basv): automatically wrap SampleStatistics with MeanStatistic
        raise ValueError(
            "Invalid entry in statistics argument: only "
            "MeanStatistics are supported.")

  # Initialize new members:
  # TODO(basv): parametrize re-initialization or shared initialization.
  for _ in range(len(self.weights_list), self.n_members):
    assert initial_epoch == 0, "New members can only be initialized at start."
    cloned = bnnmodel.clone_model_and_weights(self.model, self.input_shape)
    self.weights_list.append(cloned.get_weights())

  # Resolve the (inputs, outputs) pair handed to every model.fit() call.
  # This does not depend on the epoch or member, so it is done once.
  if y is not None:
    dataset_in, y_in = dataset, y
  elif isinstance(dataset, tuple) and len(dataset) == 2:
    dataset_in, y_in = dataset
  else:
    dataset_in, y_in = dataset, None

  history = []
  n_members = len(self.weights_list)
  # Add progress bar. `initial` accounts for epochs already completed before
  # `initial_epoch`, so the bar is correctly scaled when resuming training.
  with tqdm(total=epochs * n_members,
            initial=initial_epoch * n_members,
            position=0,
            leave=True,
            unit="epoch") as progress_bar:
    # Loop over epochs.
    for epoch in range(initial_epoch, epochs):
      # For each member.
      for member_index, model in enumerate(self.iter_members()):
        self._epoch_progress_logger(epoch, member_index, progress_bar,
                                    fit_verbose)
        # TODO(basv): check if member is already trained up to epoch.
        # TODO(basv): consider member callback state switching?
        # Fit for one epoch.
        model.fit(dataset_in, y_in, epochs=epoch + 1, initial_epoch=epoch,
                  verbose=fit_verbose, **fit_kwargs)
        self.weights_list[member_index] = model.get_weights()
        self._optimizers_list[member_index] = model.optimizer.get_weights()

      if validation_data is not None and statistics is not None:
        # TODO(basv): statistics per dataset?
        # TODO(basv): unlabeled statistics?
        # Report ensemble statistics.
        if epoch % validation_freq == 0:  # TODO(basv): verify validation_freq
          # Clear accumulated statistic state before this evaluation.
          for stats_list in ordered_stats_lists:
            for stat in stats_list:
              stat.reset()
          results = self.evaluate_ensemble(validation_data, statistics)
          history.append(results)

      # [Handle Callbacks]
      # TODO(basv): implement some kind of ensemble callback structure.
      # TODO(basv): implement ensemble checkpointing.
      # TODO(basv): implement parallel workers and evaluate worker.
  return history