Ejemplo n.º 1
0
    def read_serialized_keras_model(self, ckpt_path, model, custom_objects=None):
        """
        Returns serialized keras model. On Databricks, only TFKeras is supported, not BareKeras.

        :param ckpt_path: Path to the checkpoint file to load.
        :param model: A keras model providing the model structure. It is only used on
            TensorFlow < 2.0, where the checkpoint file is treated as containing only the
            model weights; those weights are loaded into this object in place.
        :param custom_objects: Optional mapping of names to custom classes/functions that
            must be resolvable while loading the model on TensorFlow >= 2.0. ``None``
            (the default) means no custom objects, matching the method's former behavior.
        :return: The value of ``TFKerasUtil.serialize_model`` for the loaded model.
        """
        import tensorflow
        from tensorflow import keras
        from horovod.spark.keras.util import TFKerasUtil

        if LooseVersion(tensorflow.__version__) < LooseVersion("2.0.0"):
            # The checkpoint is treated as weights-only here, so restore the weights
            # into the caller-supplied model structure (this mutates `model`).
            model.load_weights(ckpt_path)
        else:
            # Older tf.keras custom_object_scope() cannot take None, so fall back to an
            # empty mapping, which is equivalent to having no scope at all.
            with keras.utils.custom_object_scope(custom_objects or {}):
                model = keras.models.load_model(ckpt_path)
        return TFKerasUtil.serialize_model(model)
Ejemplo n.º 2
0
    def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
        """
        Returns serialized keras model.
        The parameter `model` is for providing the model structure when the checkpoint file only
        contains model weights.

        :param ckpt_path: Path to the checkpoint file to load.
        :param model: A keras model providing the model structure. It is only used on
            TensorFlow < 2.0, where its weights are loaded from ``ckpt_path`` in place.
        :param custom_objects: Mapping of names to custom classes/functions made available
            while loading the model on TensorFlow >= 2.0. ``None`` is treated as an empty
            mapping.
        :return: The value of ``TFKerasUtil.serialize_model`` for the loaded model.
        """
        import tensorflow
        from tensorflow import keras
        from horovod.spark.keras.util import TFKerasUtil

        if LooseVersion(tensorflow.__version__) < LooseVersion("2.0.0"):
            # Weights-only checkpoint: restore into the caller-supplied structure
            # (this mutates `model`).
            model.load_weights(ckpt_path)
        else:
            # Older tf.keras custom_object_scope() raises TypeError on None, so
            # substitute an empty mapping.
            with keras.utils.custom_object_scope(custom_objects or {}):
                model = keras.models.load_model(ckpt_path)
        return TFKerasUtil.serialize_model(model)
Ejemplo n.º 3
0
    def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
        """Reads the checkpoint file of the keras model into model bytes and returns the base 64
        encoded model bytes.

        :param ckpt_path: A string of path to the checkpoint file.
        :param model: A keras model. It is not used by this implementation; it is used in
            DBFSLocalStore.read_serialized_keras_model() when the ckpt_path only contains
            model weights.
        :param custom_objects: A mapping of custom objects made available while loading the
            keras model on TensorFlow >= 2.0. ``None`` is treated as an empty mapping.
        :return: the base 64 encoded model bytes of the checkpoint model.
        """
        from horovod.runner.common.util import codec
        import tensorflow
        from tensorflow import keras
        from horovod.spark.keras.util import TFKerasUtil

        if LooseVersion(tensorflow.__version__) < LooseVersion("2.0.0"):
            # The checkpoint file's contents are encoded as-is, without being loaded
            # through keras. Presumably self.read() returns the raw file bytes - confirm.
            model_bytes = self.read(ckpt_path)
            return codec.dumps_base64(model_bytes)

        # TF >= 2.0: load the model (so custom layers/functions resolve) and re-serialize.
        # Older tf.keras custom_object_scope() raises TypeError on None, so substitute
        # an empty mapping.
        with keras.utils.custom_object_scope(custom_objects or {}):
            model = keras.models.load_model(ckpt_path)
        return TFKerasUtil.serialize_model(model)