def load_model(self, ensemble_dir):
    """Restore `self.model` from the `model.json` file in `ensemble_dir`.

    The json string is deserialized with
    tf.keras.models.model_from_json() inside `bnnmodel.bnn_scope()`.

    Note: further custom objects can be deserialized by calling this
    method from inside a
    `with tf.keras.utils.CustomObjectScope({'ObjectName': ObjectName}):`
    scope.

    Ensemble weights are not loaded here, see `self.save_ensemble()`.

    Args:
      ensemble_dir: directory that contains the serialized `model.json`.
    """
    # TODO(basv) consider tf.distribute.Strategy api and model building.
    model_json_path = os.path.join(ensemble_dir, "model.json")
    with tf.io.gfile.GFile(model_json_path, "r") as json_file:
        serialized_model = str(json_file.read())

    # Deserialize within the bnn scope, exactly where the model is assigned.
    with bnnmodel.bnn_scope():
        self.model = tf.keras.models.model_from_json(serialized_model)
def __call__(self, layer):
    """Rebuild a freshly constructed layer with a suitable prior attached.

    Layers that are not trainable are returned unchanged. For trainable
    layers the serialized config is handed to `self._update_prior()` and a
    new layer of the same type is created from that config.

    Args:
      layer: tf.keras.layers.Layer that has just been constructed
        (not built, no graph).

    Returns:
      layer_out: the input layer itself when it is not trainable, otherwise
        a new layer of the same type built from the updated config.
    """
    if layer.trainable:
        # Serialize the layer so that its priors can be swapped out.
        # NOTE(review): `_update_prior` presumably edits `layer_config` in
        # place, since its return value is not used -- confirm.
        layer_config = layer.get_config()
        self._update_prior(layer, layer_config)

        # Re-create the layer from the updated config within the bnn scope.
        with bnnmodel.bnn_scope():
            rebuilt_layer = type(layer).from_config(layer_config)
        return rebuilt_layer

    return layer