def load_weights(self, filepath, by_name=False):
    """Loads the weights of `self.layers` from an HDF5 file.

    Arguments:
        filepath: String, path to the HDF5 file holding the weights.
        by_name: Boolean. When True the weights are loaded with
            `topology.load_weights_from_hdf5_group_by_name`; otherwise
            with `topology.load_weights_from_hdf5_group`.

    Raises:
        ImportError: if h5py is not available.
    """
    if h5py is None:
        raise ImportError('`load_weights` requires h5py.')

    # Keep a handle on the *file* for the whole load and let the context
    # manager close it. Closing whatever the weights variable points at is
    # not enough: once it refers to the nested `model_weights` group there
    # is no `close()` to call, and the file would stay open. The `with`
    # also releases the file if loading raises.
    with h5py.File(filepath, mode='r') as h5_file:
        weights_group = h5_file
        if 'layer_names' not in h5_file.attrs and 'model_weights' in h5_file:
            # No top-level layer listing: the weights live in the nested
            # `model_weights` group instead of the file root.
            weights_group = h5_file['model_weights']

        if by_name:
            topology.load_weights_from_hdf5_group_by_name(
                weights_group, self.layers)
        else:
            topology.load_weights_from_hdf5_group(weights_group, self.layers)
def load_model(filepath, custom_objects=None):
    """Loads a model saved via `save_model`.

    Arguments:
        filepath: String, path to the saved model.
        custom_objects: Optional dictionary mapping names (strings) to
            custom classes or functions to be considered during
            deserialization.

    Returns:
        A Keras model instance. If an optimizer was found
        as part of the saved model, the model is already
        compiled. Otherwise, the model is uncompiled and
        a warning will be displayed.

    Raises:
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
    """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')

    if not custom_objects:
        custom_objects = {}

    def as_text(value):
        """Returns `value` as text, decoding it only when it is bytes.

        Depending on the h5py version, string attributes come back either
        as bytes or as already-decoded `str`; the latter has no `decode`.
        """
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def convert_custom_objects(obj):
        """Handles custom object lookup.

        Arguments:
            obj: object, dict, or list.

        Returns:
            The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
        """
        if isinstance(obj, list):
            return [custom_objects.get(item, item) for item in obj]
        if isinstance(obj, dict):
            converted = {}
            for key, value in obj.items():
                if isinstance(value, list):
                    # A list is unhashable, so it cannot be looked up
                    # directly; resolve its elements one by one instead.
                    converted[key] = [
                        custom_objects.get(item, item) for item in value
                    ]
                else:
                    converted[key] = custom_objects.get(value, value)
            return converted
        return custom_objects.get(obj, obj)

    # The context manager closes the file on every exit path, including the
    # ValueError below and any failure while deserializing or compiling.
    with h5py.File(filepath, mode='r') as f:
        # Instantiate the model from its stored JSON config.
        model_config = f.attrs.get('model_config')
        if model_config is None:
            raise ValueError('No model found in config file.')
        model_config = json.loads(as_text(model_config))
        model = model_from_config(model_config, custom_objects=custom_objects)

        # Set weights.
        topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)

        # Instantiate the optimizer, if a training config was saved.
        training_config = f.attrs.get('training_config')
        if training_config is None:
            warnings.warn('No training configuration found in save file: '
                          'the model was *not* compiled. Compile it manually.')
            return model
        training_config = json.loads(as_text(training_config))
        optimizer_config = training_config['optimizer_config']
        optimizer = optimizers.deserialize(
            optimizer_config, custom_objects=custom_objects)

        # Recover loss functions and metrics.
        loss = convert_custom_objects(training_config['loss'])
        metrics = convert_custom_objects(training_config['metrics'])
        sample_weight_mode = training_config['sample_weight_mode']
        loss_weights = training_config['loss_weights']

        # Compile model.
        model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            sample_weight_mode=sample_weight_mode)

        # Set optimizer weights.
        if 'optimizer_weights' in f:
            # Build train function (to get weight updates).
            if isinstance(model, Sequential):
                model.model._make_train_function()
            else:
                model._make_train_function()
            optimizer_weights_group = f['optimizer_weights']
            optimizer_weight_names = [
                as_text(n)
                for n in optimizer_weights_group.attrs['weight_names']
            ]
            optimizer_weight_values = [
                optimizer_weights_group[n] for n in optimizer_weight_names
            ]
            # Must run while the file is still open: the values are read
            # from the HDF5 group here.
            model.optimizer.set_weights(optimizer_weight_values)
        return model
def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
    """Loads a model saved via `save_model`.

    Arguments:
        filepath: String, path to the saved model.
        custom_objects: Optional dictionary mapping names (strings) to
            custom classes or functions to be considered during
            deserialization.
        compile: Boolean, whether to compile the model after loading.

    Returns:
        A Keras model instance. If an optimizer was found
        as part of the saved model, the model is already
        compiled. Otherwise, the model is uncompiled and
        a warning will be displayed. When `compile` is set
        to False, the compilation is omitted without any
        warning.

    Raises:
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
    """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')

    if not custom_objects:
        custom_objects = {}

    def as_text(value):
        """Returns `value` as text, decoding it only when it is bytes.

        Depending on the h5py version, string attributes come back either
        as bytes or as already-decoded `str`; the latter has no `decode`.
        """
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def convert_custom_objects(obj):
        """Handles custom object lookup.

        Arguments:
            obj: object, dict, or list.

        Returns:
            The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
        """
        if isinstance(obj, list):
            return [custom_objects.get(item, item) for item in obj]
        if isinstance(obj, dict):
            converted = {}
            for key, value in obj.items():
                if isinstance(value, list):
                    # Lists are unhashable, so resolve them element-wise.
                    converted[key] = [
                        custom_objects.get(item, item) for item in value
                    ]
                else:
                    converted[key] = custom_objects.get(value, value)
            return converted
        return custom_objects.get(obj, obj)

    # The context manager closes the file on every exit path: the early
    # returns, the ValueError below, and any failure while deserializing,
    # compiling or restoring optimizer state.
    with h5py.File(filepath, mode='r') as f:
        # Instantiate the model from its stored JSON config.
        model_config = f.attrs.get('model_config')
        if model_config is None:
            raise ValueError('No model found in config file.')
        model_config = json.loads(as_text(model_config))
        model = model_from_config(model_config, custom_objects=custom_objects)

        # Set weights.
        topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)

        # Early return if compilation is not required.
        if not compile:
            return model

        # Instantiate the optimizer, if a training config was saved.
        training_config = f.attrs.get('training_config')
        if training_config is None:
            logging.warning('No training configuration found in save file: '
                            'the model was *not* compiled. Compile it manually.')
            return model
        training_config = json.loads(as_text(training_config))
        optimizer_config = training_config['optimizer_config']
        optimizer = optimizers.deserialize(
            optimizer_config, custom_objects=custom_objects)

        # Recover loss functions and metrics.
        loss = convert_custom_objects(training_config['loss'])
        metrics = convert_custom_objects(training_config['metrics'])
        sample_weight_mode = training_config['sample_weight_mode']
        loss_weights = training_config['loss_weights']

        # Compile model.
        model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            sample_weight_mode=sample_weight_mode)

        # Set optimizer weights.
        if 'optimizer_weights' in f:
            # Build train function (to get weight updates).
            if isinstance(model, Sequential):
                model.model._make_train_function()
            else:
                model._make_train_function()
            optimizer_weights_group = f['optimizer_weights']
            optimizer_weight_names = [
                as_text(n)
                for n in optimizer_weights_group.attrs['weight_names']
            ]
            optimizer_weight_values = [
                optimizer_weights_group[n] for n in optimizer_weight_names
            ]
            # Must run while the file is still open: the values are read
            # from the HDF5 group here.
            model.optimizer.set_weights(optimizer_weight_values)
        return model