# Imports assumed by the helpers below (paths follow tf.keras internals; adjust
# to the surrounding codebase as needed).
import copy

from tensorflow.keras import Model
from tensorflow.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils


def _pack_obj(obj):
    """Recursively pack an object into a picklable form."""
    try:
        return copy.deepcopy(obj)
    except TypeError:
        pass
    # Is this a Keras serializable?
    try:
        model_metadata = saving_utils.model_metadata(obj)
        training_config = model_metadata["training_config"]
        model = serialize(obj)
        weights = obj.get_weights()
        return SavedKerasModel(
            cls=obj.__class__,
            model=model,
            weights=weights,
            training_config=training_config,
        )
    except (TypeError, AttributeError):
        pass
    # Try manually packing the object's attributes.
    if hasattr(obj, "__dict__"):
        for key, val in obj.__dict__.items():
            obj.__dict__[key] = _pack_obj(val)
        return obj
    # Pack the elements of lists and tuples, preserving the container type.
    if isinstance(obj, (list, tuple)):
        obj_type = type(obj)
        new_obj = obj_type([_pack_obj(o) for o in obj])
        return new_obj
    return obj
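# A minimal sketch of the ``SavedKerasModel`` container referenced above. Its
# real definition is not shown in this file, so this is an assumption: a plain
# namedtuple holding exactly the pieces that ``_pack_obj`` extracts.
from collections import namedtuple

SavedKerasModel = namedtuple(
    "SavedKerasModel", ["cls", "model", "weights", "training_config"]
)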
def pack_keras_model(model_obj, protocol):
    """Pickle a Keras Model.

    Parameters
    ----------
    model_obj : Model
        An instance of a Keras Model.
    protocol : int
        Pickle protocol version; ignored.

    Returns
    -------
    tuple
        A ``(reconstructor, args)`` pair following the pickle protocol.
    """
    if not isinstance(model_obj, Model):
        raise TypeError("`model_obj` must be an instance of a Keras Model")
    # Pack up the model: serialized config, training config, and weights.
    model_metadata = saving_utils.model_metadata(model_obj)
    training_config = model_metadata.get("training_config", None)
    model = serialize(model_obj)
    weights = model_obj.get_weights()
    return (unpack_keras_model, (model, training_config, weights))
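# A minimal sketch of the ``unpack_keras_model`` reconstructor named in the
# tuple returned above; it is not defined in this file. The sketch assumes
# ``deserialize`` mirrors ``serialize`` and that
# ``saving_utils.compile_args_from_training_config`` is available; the real
# implementation may differ.
def unpack_keras_model(model, training_config, weights):
    """Rebuild a Keras Model from its serialized config, training config, and weights."""
    restored_model = deserialize(model)
    if training_config is not None:
        # Restore optimizer/loss/metrics so the model comes back compiled.
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(training_config)
        )
    restored_model.set_weights(weights)
    return restored_model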
def __reduce__(self):
    """Pickle hook for the model class: defer reconstruction to ``unpack``
    with the serialized config, training config, and weights."""
    model_metadata = saving_utils.model_metadata(self)
    training_config = model_metadata.get("training_config", None)
    model = serialize(self)
    weights = self.get_weights()
    return (unpack, (model, training_config, weights))
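# Usage sketch: with ``__reduce__`` defined on the model class, instances pickle
# with the standard library directly. ``build_model`` is a hypothetical factory
# and ``unpack`` is assumed to be defined alongside this method.
import pickle

model = build_model()          # any instance of the class defining __reduce__
data = pickle.dumps(model)     # invokes __reduce__, storing (unpack, args)
restored = pickle.loads(data)  # calls unpack(model, training_config, weights)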
def _reduce_tf_model(model):
    """Reducer for Keras models: return the reconstructor and the state it needs."""
    model_metadata = saving_utils.model_metadata(model)
    training_config = model_metadata.get("training_config", None)
    # Grab the weights before rebinding ``model`` to its serialized config.
    weights = model.get_weights()
    model = serialize(model)
    return TensorflowDispatcher._make_model, (model, training_config, weights)
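# Usage sketch: one way ``_reduce_tf_model`` could be wired in without the (not
# shown) ``TensorflowDispatcher`` plumbing, assuming a custom pickler is
# acceptable. ``reducer_override`` (Python 3.8+) routes every Keras ``Model``
# instance, including subclasses, through the reducer.
import io
import pickle

class _TFModelPickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, Model):
            return _reduce_tf_model(obj)
        return NotImplemented  # defer to default pickling for everything else

def dumps_with_tf_support(obj):
    buf = io.BytesIO()
    _TFModelPickler(buf, protocol=pickle.HIGHEST_PROTOCOL).dump(obj)
    return buf.getvalue()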