Example #1
 def load_weights(self, filepath, by_name=False):
     if h5py is None:
         raise ImportError('`load_weights` requires h5py.')
     f = h5py.File(filepath, mode='r')
     if 'layer_names' not in f.attrs and 'model_weights' in f:
         f = f['model_weights']
     layers = self.layers
     if by_name:
         topology.load_weights_from_hdf5_group_by_name(f, layers)
     else:
         topology.load_weights_from_hdf5_group(f, layers)
     if hasattr(f, 'close'):
         f.close()
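
The `'model_weights'` check above is what lets the same loader read two kinds of HDF5 files: a weights-only file written by `save_weights` keeps `layer_names` as a top-level attribute, while a full model file written by `save_model` stores the weights under a `model_weights` group (next to the `model_config` and `training_config` attributes used by `load_model` further below). A minimal way to see which kind of file you have is to inspect it with h5py; the path below is a placeholder, not part of the original code.

import h5py

# 'my_model.h5' is a hypothetical file produced by either save_weights or save_model.
with h5py.File('my_model.h5', mode='r') as f:
    if 'layer_names' in f.attrs:
        print('weights-only file; layers:', list(f.attrs['layer_names']))
    elif 'model_weights' in f:
        print('full-model file; weight groups:', list(f['model_weights'].keys()))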
Example #2
 def load_weights(self, filepath, by_name=False):
   if h5py is None:
     raise ImportError('`load_weights` requires h5py.')
   f = h5py.File(filepath, mode='r')
   if 'layer_names' not in f.attrs and 'model_weights' in f:
     f = f['model_weights']
   layers = self.layers
   if by_name:
     topology.load_weights_from_hdf5_group_by_name(f, layers)
   else:
     topology.load_weights_from_hdf5_group(f, layers)
   if hasattr(f, 'close'):
     f.close()
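
A short usage sketch for the method shown above (the model definition, layer names, and file path are illustrative assumptions, not part of the original code). With the default `by_name=False` the weights are matched to layers by topological order, whereas `by_name=True` matches layers by name, which is useful when the saved architecture only partially overlaps with the current one.

from keras.models import Sequential
from keras.layers import Dense

# Hypothetical model; with by_name=False it must have the same layer
# topology as the model whose weights were saved.
model = Sequential([
    Dense(32, input_shape=(784,), name='hidden'),
    Dense(10, name='output'),
])

model.load_weights('weights.h5')                # match layers by position
model.load_weights('weights.h5', by_name=True)  # match layers by their names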
Example #3
def load_model(filepath, custom_objects=None):
    """Loads a model saved via `save_model`.

  Arguments:
      filepath: String, path to the saved model.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')

    if not custom_objects:
        custom_objects = {}

    def convert_custom_objects(obj):
        """Handles custom object lookup.

    Arguments:
        obj: object, dict, or list.

    Returns:
        The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
    """
        if isinstance(obj, list):
            deserialized = []
            for value in obj:
                if value in custom_objects:
                    deserialized.append(custom_objects[value])
                else:
                    deserialized.append(value)
            return deserialized
        if isinstance(obj, dict):
            deserialized = {}
            for key, value in obj.items():
                if value in custom_objects:
                    deserialized[key] = custom_objects[value]
                else:
                    deserialized[key] = value
            return deserialized
        if obj in custom_objects:
            return custom_objects[obj]
        return obj

    f = h5py.File(filepath, mode='r')

    # instantiate model
    model_config = f.attrs.get('model_config')
    if model_config is None:
        raise ValueError('No model found in config file.')
    model_config = json.loads(model_config.decode('utf-8'))
    model = model_from_config(model_config, custom_objects=custom_objects)

    # set weights
    topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)

    # instantiate optimizer
    training_config = f.attrs.get('training_config')
    if training_config is None:
        warnings.warn('No training configuration found in save file: '
                      'the model was *not* compiled. Compile it manually.')
        f.close()
        return model
    training_config = json.loads(training_config.decode('utf-8'))
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config,
                                       custom_objects=custom_objects)

    # Recover loss functions and metrics.
    loss = convert_custom_objects(training_config['loss'])
    metrics = convert_custom_objects(training_config['metrics'])
    sample_weight_mode = training_config['sample_weight_mode']
    loss_weights = training_config['loss_weights']

    # Compile model.
    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  loss_weights=loss_weights,
                  sample_weight_mode=sample_weight_mode)

    # Set optimizer weights.
    if 'optimizer_weights' in f:
        # Build train function (to get weight updates).
        if isinstance(model, Sequential):
            model.model._make_train_function()
        else:
            model._make_train_function()
        optimizer_weights_group = f['optimizer_weights']
        optimizer_weight_names = [
            n.decode('utf8')
            for n in optimizer_weights_group.attrs['weight_names']
        ]
        optimizer_weight_values = [
            optimizer_weights_group[n] for n in optimizer_weight_names
        ]
        model.optimizer.set_weights(optimizer_weight_values)
    f.close()
    return model
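
A hedged usage sketch for this version of `load_model`; the custom loss and file name below are illustrative assumptions. Anything custom that was referenced when the model was saved (layers, losses, metrics) has to be supplied via `custom_objects` so its name can be resolved during deserialization; if the file also contains a `training_config`, the returned model comes back already compiled.

from keras import backend as K
from keras.models import load_model

def custom_mae(y_true, y_pred):
    # Illustrative custom loss assumed to have been used when the model was saved.
    return K.mean(K.abs(y_pred - y_true), axis=-1)

# 'my_model.h5' is a placeholder path for a file written by save_model / model.save.
model = load_model('my_model.h5', custom_objects={'custom_mae': custom_mae})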
Example #4
def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model`.

  Arguments:
      filepath: String, path to the saved model.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
  if h5py is None:
    raise ImportError('`load_model` requires h5py.')

  if not custom_objects:
    custom_objects = {}

  def convert_custom_objects(obj):
    """Handles custom object lookup.

    Arguments:
        obj: object, dict, or list.

    Returns:
        The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
    """
    if isinstance(obj, list):
      deserialized = []
      for value in obj:
        if value in custom_objects:
          deserialized.append(custom_objects[value])
        else:
          deserialized.append(value)
      return deserialized
    if isinstance(obj, dict):
      deserialized = {}
      for key, value in obj.items():
        deserialized[key] = []
        if isinstance(value, list):
          for element in value:
            if element in custom_objects:
              deserialized[key].append(custom_objects[element])
            else:
              deserialized[key].append(element)
        elif value in custom_objects:
          deserialized[key] = custom_objects[value]
        else:
          deserialized[key] = value
      return deserialized
    if obj in custom_objects:
      return custom_objects[obj]
    return obj

  f = h5py.File(filepath, mode='r')

  # instantiate model
  model_config = f.attrs.get('model_config')
  if model_config is None:
    raise ValueError('No model found in config file.')
  model_config = json.loads(model_config.decode('utf-8'))
  model = model_from_config(model_config, custom_objects=custom_objects)

  # set weights
  topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)

  # Early return if compilation is not required.
  if not compile:
    f.close()
    return model

  # instantiate optimizer
  training_config = f.attrs.get('training_config')
  if training_config is None:
    logging.warning('No training configuration found in save file: '
                    'the model was *not* compiled. Compile it manually.')
    f.close()
    return model
  training_config = json.loads(training_config.decode('utf-8'))
  optimizer_config = training_config['optimizer_config']
  optimizer = optimizers.deserialize(
      optimizer_config, custom_objects=custom_objects)

  # Recover loss functions and metrics.
  loss = convert_custom_objects(training_config['loss'])
  metrics = convert_custom_objects(training_config['metrics'])
  sample_weight_mode = training_config['sample_weight_mode']
  loss_weights = training_config['loss_weights']

  # Compile model.
  model.compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)

  # Set optimizer weights.
  if 'optimizer_weights' in f:
    # Build train function (to get weight updates).
    if isinstance(model, Sequential):
      model.model._make_train_function()
    else:
      model._make_train_function()
    optimizer_weights_group = f['optimizer_weights']
    optimizer_weight_names = [
        n.decode('utf8') for n in optimizer_weights_group.attrs['weight_names']
    ]
    optimizer_weight_values = [
        optimizer_weights_group[n] for n in optimizer_weight_names
    ]
    model.optimizer.set_weights(optimizer_weight_values)
  f.close()
  return model
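
A minimal sketch of the new `compile` flag (the import path and file name are assumptions; adjust for `tf.keras` if that is where this version lives). Passing `compile=False` returns the architecture and weights without touching the stored optimizer, which helps when the original optimizer or loss is not available in the current environment; the model can then be compiled manually.

from keras.models import load_model

# Load architecture and weights only; no optimizer state is restored
# and no "not compiled" warning is emitted.
model = load_model('my_model.h5', compile=False)

# Compile manually before training or evaluating.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])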