def test_clone_subclassed(self):
    """Cloning a subclassed, two-output model preserves its predictions."""

    class TestModel(tf.keras.Model):
      """Shared ReLU Dense layer feeding two linear Dense output heads."""

      def __init__(self):
        super(TestModel, self).__init__()
        self.hidden = tf.keras.layers.Dense(10, activation='relu')
        self.out1 = tf.keras.layers.Dense(10, name='a')
        self.out2 = tf.keras.layers.Dense(10, name='nolabel')

      def call(self, inputs):
        shared = self.hidden(inputs)
        return [self.out1(shared), self.out2(shared)]

      def get_config(self):
        # The constructor takes no arguments, so there is nothing to record.
        return {}

      @staticmethod
      def from_config(config):
        del config  # Unused: a fresh TestModel is returned.
        return TestModel()

    shape = (None, 10)
    source = TestModel()
    source.build(shape)

    batch = tf.random.normal((20, 10))
    clone = bnnmodel.clone_model_and_weights(source, shape)

    self.assertAllClose(source(batch), clone(batch),
                        msg='model2 output differs from model1')
  def __init__(self,
               model,
               input_shape=None,
               weights_list=None,
               clone_model=True):
    """Initialize an empirical ensemble.

    Args:
      model: tf.keras.models.Model, tf.keras.models.Sequential, or a factory
        that instantiates an object that behaves as tf.keras.Model, e.g.
        `lambda: YourModelClass(param=value, ...)`. A factory is only called
        when `clone_model` is True.
      input_shape: input shape handed to `bnnmodel.clone_model_and_weights`
        (Keras model) or to `Model.build` (factory). Must not be None when
        `clone_model` is True; TODO(nowozin): we hope this requirement will
        change in the future, see b/132994200. Stored as `self.input_shape`.
      weights_list: list of weights compatible with `model`. Defaults to a
        new empty list.
      clone_model: bool, default True. If True, a Keras `model` is cloned
        together with its weights, and a factory is called and the result
        built. If False, `model` is stored as given, without cloning, calling
        or building it; use False when training with keras model.fit().

    Raises:
      ValueError: if `clone_model` is True and `input_shape` is None.
    """
    if not clone_model:
      # Keep the caller's object itself: no clone, no factory call, no build.
      self.model = model
    elif input_shape is None:
      raise ValueError("input_shape cannot be None in EmpiricalEnsemble "
                       "constructor")
    elif isinstance(model, tf.keras.Model):
      self.model = bnnmodel.clone_model_and_weights(model, input_shape)
    else:
      # `model` is a factory: create a fresh instance and build it.
      self.model = model()
      self.model.build(input_shape)

    self.input_shape = input_shape
    self.weights_list = [] if weights_list is None else weights_list
  def test_clone(self):
    """A cloned, regularized Sequential model reproduces the original output."""
    spike_slab = prior.SpikeAndSlabRegularizer(weight=1.0)
    model1 = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            10,
            activation='relu',
            kernel_regularizer=spike_slab,
            bias_regularizer=spike_slab,
        ),
        tf.keras.layers.Dense(10),
    ])

    inputs = tf.random.normal((256, 10))
    expected = model1(inputs)

    clone = bnnmodel.clone_model_and_weights(model1, (None, 10))
    # Release the source model before running the clone.
    del model1
    actual = clone(inputs)

    self.assertAllClose(expected, actual, msg='Model cloning failed.')
    def fit(self,
            dataset,
            y=None,
            statistics=None,
            epochs=1,
            validation_split=0.,
            validation_data=None,
            validation_freq=1,
            initial_epoch=0,
            **fit_kwargs):
        """Trains ensemble members using a different initialization for each model.

        Note: alpha functionality, experimental and prone to change.

        Members missing from `self.weights_list` (up to `self.n_members`) are
        first initialized from a clone of `self.model`. Then, for every epoch,
        each member is fit for exactly one epoch, and its weights and optimizer
        weights are stored back into `self.weights_list` and
        `self._optimizers_list`.

        Arguments:
          dataset: Either:
            - tf.data.Dataset yielding (inputs, outputs, [sample_weight])
              tuples suitable for the self.model.fit() API. Inputs/outputs can
              be singular, tuples (ordered according to
              self.model.inputs/outputs) or dict (with keys matching
              self.model.input/output_names).
            - Tuple (inputs, outputs) suitable for self.model.fit().
            - Inputs suitable for self.model.predict(inputs).
          y: outputs in case `dataset` is an input tensor. When given, `dataset`
            is passed to member `fit` as the inputs, unchanged.
          statistics: mean statistics to evaluate on output, see
            self.model.compile(metrics=...) API:
            - For single-output model (e.g. keras.Sequential()):
              [Statistic, ...].
            - Multiple-output model (e.g. keras.Model(outputs=(...,))):
              tuple([Statistic, ...], ...) in order of outputs.
            - Multiple named output model (e.g. keras.Model(outputs={...})):
              dict(output_name=[Statistic, ...], ...)
            Every entry must be a `stats.MeanStatistic`.
          epochs: Index of the final epoch; training runs epochs
            `initial_epoch .. epochs - 1`.
          validation_split: Float between 0 and 1. Only checked by
            `self._validate_fit_params`; it is not forwarded to member `fit`
            calls.
          validation_data: Data on which ensemble `statistics` are evaluated
            (via `self.evaluate_ensemble`) at the end of validated epochs. It
            is not forwarded to member `fit` calls. Statistics are only
            computed when both `validation_data` and `statistics` are given.
          validation_freq: Integer. Ensemble statistics are computed on
            (0-indexed) epochs where `epoch % validation_freq == 0`. The
            per-epoch check below only handles integers, not
            `collections.Container` values.
          initial_epoch: Integer. Epoch at which to start training (useful
            for resuming a previous training run). Must be 0 if new members
            still need to be initialized.
          **fit_kwargs: Additional arguments forwarded to each member's
            `fit` call. `verbose` is taken out and passed explicitly.

        Returns:
          history: A list with one `self.evaluate_ensemble` result per
            validated epoch, following the format of `statistics`. Empty when
            `validation_data` or `statistics` is None.

        Raises:
          ValueError: If `statistics` contains an entry that is not a
            `stats.MeanStatistic`. Presumably also raised by member `fit`
            when the input data does not match what the model expects.
          RuntimeError: Presumably raised by member `fit` if the model was
            never compiled.
          AssertionError: If new members must be initialized while
            `initial_epoch != 0`.
        """
        # `verbose` is consumed here and re-passed explicitly to each member
        # fit below, so it must not also remain in `fit_kwargs`.
        fit_verbose = fit_kwargs.pop("verbose", None) or 0

        self._validate_fit_params(validation_split, validation_freq)

        # Collect output statistics.
        output_names = getattr(self.model, "output_names", None)
        ordered_stats_lists = _statistics_to_ordered_tuple(
            statistics, output_names)
        for stats_list in ordered_stats_lists:
            for stat in stats_list:
                if not isinstance(stat, stats.MeanStatistic):
                    # TODO(basv): automatically wrap SampleStatistics with
                    # MeanStatistic.
                    raise ValueError(
                        "Invalid entry in statistics argument: only "
                        "MeanStatistics are supported.")

        # Initialize new members:
        # TODO(basv): parametrize re-initialization or shared initialization.
        for _ in range(len(self.weights_list), self.n_members):
            assert initial_epoch == 0, "New members can only be initialized at start."
            cloned = bnnmodel.clone_model_and_weights(self.model,
                                                      self.input_shape)
            self.weights_list.append(cloned.get_weights())

        # Split `dataset` into (inputs, targets) once; the split does not
        # depend on the epoch or the member.
        if y is not None:
            dataset_in, y_in = dataset, y
        elif isinstance(dataset, tuple) and len(dataset) == 2:
            dataset_in, y_in = dataset
        else:
            dataset_in, y_in = dataset, None

        history = []

        # Add progress bar.
        # NOTE(review): `total` counts all `epochs` even when resuming from
        # `initial_epoch > 0`; presumably `_epoch_progress_logger` positions
        # the bar by absolute epoch — verify.
        with tqdm(total=epochs * len(self.weights_list),
                  position=0,
                  leave=True,
                  unit="epoch") as progress_bar:

            # Loop over epochs.
            for epoch in range(initial_epoch, epochs):
                # Train each member for one epoch.
                for member_index, model in enumerate(self.iter_members()):

                    self._epoch_progress_logger(epoch, member_index,
                                                progress_bar, fit_verbose)
                    # TODO(basv): check if member is already trained up to epoch.
                    # TODO(basv): consider member callback state switching?

                    # `epochs=epoch + 1, initial_epoch=epoch` runs exactly one
                    # epoch while keeping Keras' epoch numbering global.
                    model.fit(dataset_in,
                              y_in,
                              epochs=epoch + 1,
                              initial_epoch=epoch,
                              verbose=fit_verbose,
                              **fit_kwargs)
                    self.weights_list[member_index] = model.get_weights()
                    self._optimizers_list[
                        member_index] = model.optimizer.get_weights()

                if validation_data is not None and statistics is not None:
                    # TODO(basv): statistics per dataset?
                    # TODO(basv): unlabeled statistics?
                    # Report ensemble statistics.
                    if epoch % validation_freq == 0:  # TODO(basv): verify validation_freq
                        # Reset the statistics of every output first, then
                        # evaluate the ensemble exactly once for this epoch.
                        for stats_list in ordered_stats_lists:
                            for stat in stats_list:
                                stat.reset()
                        results = self.evaluate_ensemble(
                            validation_data, statistics)
                        history.append(results)

                # [Handle Callbacks]
                # TODO(basv): implement some kind of ensemble callback structure.
                # TODO(basv): implement ensemble checkpointing.
                # TODO(basv): implement parallel workers and evaluate worker.

        return history