def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config."""
  if custom_objects is None:
    custom_objects = {}

  def _substitute_custom(obj):
    # Replace a serialized name with the user-registered object, if any;
    # anything not in `custom_objects` passes through unchanged.
    return custom_objects.get(obj, obj)

  optimizer = optimizers.deserialize(
      training_config['optimizer_config'], custom_objects=custom_objects)

  # Recover loss functions and metrics.
  loss_config = training_config['loss']
  if isinstance(loss_config, dict) and 'class_name' in loss_config:
    # Deserialize loss class.
    loss_config = losses.get(loss_config)
  loss = nest.map_structure(_substitute_custom, loss_config)
  metrics = nest.map_structure(_substitute_custom, training_config['metrics'])
  weighted_metrics = nest.map_structure(
      _substitute_custom, training_config.get('weighted_metrics', None))

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=training_config['loss_weights'],
      sample_weight_mode=training_config['sample_weight_mode'])
def get_loss_function(loss):
  """Returns the loss function corresponding to the given loss input."""
  # `None` and already-constructed `Loss` objects pass through untouched.
  if loss is None or isinstance(loss, losses.Loss):
    return loss
  # TODO(psv): After we have added all V2 losses, update this function.
  if loss in ('mse', 'MSE', 'mean_squared_error'):
    return losses.MeanSquaredError()
  return losses.get(loss)
def get_loss_function(loss):
  """Returns the loss function corresponding to the given loss input."""
  if loss is None:
    return loss
  if isinstance(loss, losses.Loss):
    return loss
  # TODO(psv): After we have added all V2 losses, update this function.
  mse_aliases = frozenset(['mse', 'MSE', 'mean_squared_error'])
  if loss in mse_aliases:
    return losses.MeanSquaredError()
  return losses.get(loss)
def _get_loss_object(self, loss):
  """Returns a `Loss` object.

  Converts the user-supplied loss to a `Loss` object. Also allows
  `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

  Arguments:
    loss: A string, function, or `Loss` object.

  Returns:
    A `Loss` object.

  Raises:
    ValueError: If `loss` is not a callable.
  """
  if loss is None:
    return None  # Ok to have no loss for an output.

  loss = losses_mod.get(loss)
  if not isinstance(loss, losses_mod.Loss):
    # Fix: the previous code read `loss.__name__` unconditionally, which
    # raises AttributeError for callables that lack that attribute
    # (e.g. `functools.partial` objects or callable class instances).
    # Fall back to other naming sources instead of crashing.
    loss_name = getattr(loss, '__name__', None)
    if loss_name is None:
      loss_name = getattr(loss, 'name', None)
    if loss_name is None:
      if not callable(loss):
        raise ValueError('Loss should be a callable, found: {}'.format(loss))
      loss_name = loss.__class__.__name__
    loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
  loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
  return loss
def _get_loss_object(self, loss):
  """Returns a `Loss` object.

  Converts the user-supplied loss to a `Loss` object. Also allows
  `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

  Args:
    loss: A string, function, or `Loss` object.

  Returns:
    A `Loss` object.
  """
  if loss is None:
    return None  # Ok to have no loss for an output.

  resolved = losses_mod.get(loss)
  if isinstance(resolved, losses_mod.Loss):
    resolved._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return resolved

  # Plain callable: wrap it, naming the wrapper after the callable.
  wrapper_name = get_custom_object_name(resolved)
  if wrapper_name is None:
    raise ValueError('Loss should be a callable, found: {}'.format(resolved))
  wrapped = losses_mod.LossFunctionWrapper(resolved, name=wrapper_name)
  wrapped._allow_sum_over_batch_size = True  # pylint: disable=protected-access
  return wrapped
def _get_loss_object(self, loss):
  """Returns a `Loss` object.

  Converts the user-supplied loss to a `Loss` object. Also allows
  `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

  Arguments:
    loss: A string, function, or `Loss` object.

  Returns:
    A `Loss` object.
  """
  if loss is None:
    return None  # Ok to have no loss for an output.
  # TODO(omalleyt): Handle special casing for crossentropy.
  loss_obj = losses_mod.get(loss)
  if not isinstance(loss_obj, losses_mod.Loss):
    loss_obj = losses_mod.LossFunctionWrapper(loss_obj)
  # Allow AUTO and SUM_OVER_BATCH_SIZE reductions.
  # TODO(omalleyt): Can we reconcile CTL and built-in loss reductions?
  loss_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access
  return loss_obj
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model_to_hdf5`.

  Arguments:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
  if h5py is None:
    raise ImportError('`load_model` requires h5py.')

  if not custom_objects:
    custom_objects = {}

  def convert_custom_objects(obj):
    """Handles custom object lookup.

    Arguments:
        obj: object, dict, or list.

    Returns:
        The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
    """
    if isinstance(obj, list):
      deserialized = []
      for value in obj:
        deserialized.append(convert_custom_objects(value))
      return deserialized
    if isinstance(obj, dict):
      deserialized = {}
      for key, value in obj.items():
        deserialized[key] = convert_custom_objects(value)
      return deserialized
    if obj in custom_objects:
      return custom_objects[obj]
    return obj

  # Only close the file in `finally` if we opened it here; a caller-supplied
  # `h5py.File` stays open for the caller to manage.
  opened_new_file = not isinstance(filepath, h5py.File)
  if opened_new_file:
    f = h5py.File(filepath, mode='r')
  else:
    f = filepath

  model = None
  try:
    # instantiate model
    model_config = f.attrs.get('model_config')
    if model_config is None:
      raise ValueError('No model found in config file.')
    model_config = json.loads(model_config.decode('utf-8'))
    model = model_config_lib.model_from_config(model_config,
                                               custom_objects=custom_objects)

    # set weights
    load_weights_from_hdf5_group(f['model_weights'], model.layers)

    if compile:
      # instantiate optimizer
      training_config = f.attrs.get('training_config')
      if training_config is None:
        logging.warning('No training configuration found in save file: '
                        'the model was *not* compiled. Compile it manually.')
        return model
      training_config = json.loads(training_config.decode('utf-8'))
      optimizer_config = training_config['optimizer_config']
      optimizer = optimizers.deserialize(
          optimizer_config, custom_objects=custom_objects)

      # Recover loss functions and metrics.
      loss_config = training_config['loss']
      # Deserialize loss class.
      if isinstance(loss_config, dict) and 'class_name' in loss_config:
        loss_config = losses.get(loss_config)
      loss = convert_custom_objects(loss_config)
      metrics = convert_custom_objects(training_config['metrics'])
      weighted_metrics = convert_custom_objects(
          training_config.get('weighted_metrics', None))
      sample_weight_mode = training_config['sample_weight_mode']
      loss_weights = training_config['loss_weights']

      # Compile model.
      model.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    weighted_metrics=weighted_metrics,
                    loss_weights=loss_weights,
                    sample_weight_mode=sample_weight_mode)

      # Set optimizer weights.
      if 'optimizer_weights' in f:
        # Build train function (to get weight updates).
        # Models that aren't graph networks must wait until they are called
        # with data to _make_train_function() and so can't load optimizer
        # weights.
        if model._is_graph_network:  # pylint: disable=protected-access
          model._make_train_function()
          optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
          try:
            model.optimizer.set_weights(optimizer_weight_values)
          except ValueError:
            logging.warning('Error in loading the saved optimizer '
                            'state. As a result, your model is '
                            'starting with a freshly initialized '
                            'optimizer.')
        else:
          # Fix: the original concatenated 'your model is' + 'starting ...'
          # without a separating space, logging "your model isstarting".
          logging.warning('Sequential models without an `input_shape` '
                          'passed to the first layer cannot reload their '
                          'optimizer state. As a result, your model is '
                          'starting with a freshly initialized optimizer.')
  finally:
    if opened_new_file:
      f.close()
  return model