示例#1
0
        def _pack_obj(obj):
            """Recursively convert ``obj`` into something that can be pickled.

            Strategies are tried in order:

            1. ``copy.deepcopy(obj)``; if it succeeds, the copy is returned.
            2. Treat ``obj`` as a Keras model and wrap its serialized config,
               weights and training config (``None`` when the metadata has
               no ``"training_config"`` entry) in a ``SavedKerasModel``.
            3. If ``obj`` has a ``__dict__``, pack every attribute value in
               place (this MUTATES ``obj``) and return ``obj`` itself.
            4. If ``obj`` is a list or tuple (including namedtuples), return
               a new container of the same type holding the packed elements.
            5. Otherwise return ``obj`` unchanged.
            """
            try:
                return copy.deepcopy(obj)
            except TypeError:
                pass  # not deep-copyable; maybe it is a Keras model
            try:
                model_metadata = saving_utils.model_metadata(obj)
                # Use .get: when the metadata lacks "training_config",
                # indexing would raise KeyError, which the except clause
                # below does not catch. Matches the other reducers.
                training_config = model_metadata.get("training_config", None)
                model = serialize(obj)
                weights = obj.get_weights()
                return SavedKerasModel(
                    cls=obj.__class__,
                    model=model,
                    weights=weights,
                    training_config=training_config,
                )
            except (TypeError, AttributeError):
                pass  # not a Keras model; try packing its members manually
            if hasattr(obj, "__dict__"):
                # Snapshot the items so that reassigning values cannot
                # interfere with iteration over the live dict.
                for key, val in list(obj.__dict__.items()):
                    obj.__dict__[key] = _pack_obj(val)
                return obj
            if isinstance(obj, (list, tuple)):
                packed = [_pack_obj(o) for o in obj]
                if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                    # namedtuple constructors take one positional argument
                    # per field, not a single iterable.
                    return type(obj)(*packed)
                return type(obj)(packed)

            return obj
示例#2
0
def pack_keras_model(model_obj, protocol):
    """Build the reduction tuple used to pickle a Keras Model.

    Args:
        model_obj: The Keras ``Model`` instance to pickle.
        protocol: The pickle protocol version. Accepted for signature
            compatibility but not used.

    Returns:
        A two-tuple ``(unpack_keras_model, (model, training_config, weights))``.
        ``model`` is ``serialize(model_obj)``. ``training_config`` is the
        model metadata's ``"training_config"`` entry, or ``None`` when
        absent. ``weights`` is ``model_obj.get_weights()``.

    Raises:
        TypeError: If ``model_obj`` is not an instance of ``Model``.
    """
    if isinstance(model_obj, Model):
        metadata = saving_utils.model_metadata(model_obj)
        train_cfg = metadata.get("training_config")
        reduce_args = (serialize(model_obj), train_cfg, model_obj.get_weights())
        return unpack_keras_model, reduce_args
    raise TypeError("`model_obj` must be an instance of a Keras Model")
 def __reduce__(self):
     """Return the pickle reduction tuple for this model.

     The model is captured as three pieces: its serialized config, its
     metadata's ``"training_config"`` entry (``None`` when absent) and its
     weights. These are handed to ``unpack`` when the object is unpickled.
     """
     metadata = saving_utils.model_metadata(self)
     train_cfg = metadata.get("training_config")
     return unpack, (serialize(self), train_cfg, self.get_weights())
示例#4
0
 def _reduce_tf_model(model):
     """Reduce a model (presumably a tf.keras Model) to picklable pieces.

     Returns:
         ``(TensorflowDispatcher._make_model, (config, training_config,
         weights))``. ``config`` is ``serialize(model)``.
         ``training_config`` is the metadata's ``"training_config"`` entry,
         or ``None`` when absent. ``weights`` is ``model.get_weights()``.
     """
     train_cfg = saving_utils.model_metadata(model).get("training_config")
     weights = model.get_weights()
     # Keep the serialized config in its own name instead of rebinding
     # the ``model`` parameter.
     config = serialize(model)
     return TensorflowDispatcher._make_model, (config, train_cfg, weights)