コード例 #1
0
  def load_model(self, ensemble_dir):
    """Load `self.model` from the `model.json` file inside `ensemble_dir`.

    The architecture is deserialized with `tf.keras.models.model_from_json()`
    while `bnnmodel.bnn_scope()` is active.

    Note: custom objects can be loaded by calling this function from inside a
    `with tf.keras.utils.CustomObjectScope({'ObjectName': ObjectName}):` scope.
    Ensemble weights are NOT loaded here; see `self.save_ensemble()`.

    Args:
      ensemble_dir: path to the ensemble directory; it must contain a
        `model.json` file.
    """
    # TODO(basv) consider tf.distribute.Strategy api and model building.
    model_json_path = os.path.join(ensemble_dir, "model.json")
    with tf.io.gfile.GFile(model_json_path, "r") as model_file:
      serialized_model = str(model_file.read())

    # Build the model inside the BNN scope so its custom layers resolve.
    with bnnmodel.bnn_scope():
      self.model = tf.keras.models.model_from_json(serialized_model)
コード例 #2
0
    def __call__(self, layer):
        """Add a prior to the newly constructed layer passed in.

        The layer is serialized with `get_config()`, and its prior entries are
        updated by `self._update_prior()`. A fresh layer of the same class is
        then rebuilt from that config inside `bnnmodel.bnn_scope()`.
        Non-trainable layers are returned unchanged.

        Args:
          layer: tf.keras.layers.Layer that has just been constructed (not
            built, no graph).

        Returns:
          layer_out: a new layer of the same type with a suitable prior added,
            or `layer` itself when it is not trainable.
        """
        # Frozen layers get no prior; hand the original object straight back.
        if not layer.trainable:
            return layer

        layer_config = layer.get_config()
        # NOTE(review): the return value is ignored, so `_update_prior`
        # presumably rewrites the prior entries of `layer_config` in place.
        self._update_prior(layer, layer_config)

        # Rebuild the layer (not the prior) from the updated serialized config.
        layer_cls = type(layer)
        with bnnmodel.bnn_scope():
            rebuilt_layer = layer_cls.from_config(layer_config)
        return rebuilt_layer