Ejemplo n.º 1
0
 def save_weights(self, filepath, overwrite=True):
     """Save the weights of all layers of this model to an HDF5 file.

     Arguments:
         filepath: String, path of the HDF5 file to write.
         overwrite: Whether to overwrite an existing file at `filepath`
             without asking. If False and the file exists, the user is
             prompted via `ask_to_proceed_with_overwrite`, and nothing is
             written if they decline.

     Raises:
         ImportError: if h5py is not available.
     """
     if h5py is None:
         raise ImportError('`save_weights` requires h5py.')
     # If file exists and should not be overwritten, let the user decide.
     if not overwrite and os.path.isfile(filepath):
         proceed = ask_to_proceed_with_overwrite(filepath)
         if not proceed:
             return
     # The context manager closes the HDF5 handle even if writing a layer's
     # weights raises, instead of leaking the open file.
     with h5py.File(filepath, 'w') as f:
         topology.save_weights_to_hdf5_group(f, self.layers)
         f.flush()
Ejemplo n.º 2
0
 def save_weights(self, filepath, overwrite=True):
   """Save the weights of all layers of this model to an HDF5 file.

   Arguments:
       filepath: String, path of the HDF5 file to write.
       overwrite: Whether to overwrite an existing file at `filepath`
           without asking. If False and the file exists, the user is
           prompted via `ask_to_proceed_with_overwrite`, and nothing is
           written if they decline.

   Raises:
       ImportError: if h5py is not available.
   """
   if h5py is None:
     raise ImportError('`save_weights` requires h5py.')
   # If file exists and should not be overwritten, let the user decide.
   if not overwrite and os.path.isfile(filepath):
     proceed = ask_to_proceed_with_overwrite(filepath)
     if not proceed:
       return
   # The context manager closes the HDF5 handle even if writing a layer's
   # weights raises, instead of leaking the open file.
   with h5py.File(filepath, 'w') as f:
     topology.save_weights_to_hdf5_group(f, self.layers)
     f.flush()
Ejemplo n.º 3
0
def _get_json_type(obj):
    """Serialize any object to a JSON-serializable structure.

    Used as the `default=` hook of `json.dumps` when `save_model` writes
    the model and training configs.

    Arguments:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, 'get_config'):
        return {
            'class_name': obj.__class__.__name__,
            'config': obj.get_config()
        }

    # if obj is any numpy type
    if type(obj).__module__ == np.__name__:
        # `.item()` only accepts single-element arrays, so arrays are
        # converted to (nested) lists instead.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj.item()

    # Functions (e.g. loss function) and classes. Classes are callable, so
    # this branch also covers python 'type' objects.
    if callable(obj):
        return obj.__name__

    raise TypeError('Not JSON Serializable:', obj)


def save_model(model, filepath, overwrite=True):
    """Save a model to a HDF5 file.

    The saved model contains:
        - the model's configuration (topology)
        - the model's weights
        - the model's optimizer's state (if any)

    Thus the saved model can be reinstantiated in
    the exact same state, without any of the code
    used for model definition or training.

    Arguments:
        model: Keras model instance to be saved.
        filepath: String, path where to save the model.
        overwrite: Whether we should overwrite any existing
            model at the target location, or instead
            ask the user with a manual prompt.

    Raises:
        ImportError: if h5py is not available.
        TypeError: if the model or training config contains an object
            that cannot be serialized to JSON.
    """
    if h5py is None:
        raise ImportError('`save_model` requires h5py.')

    from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
        proceed = ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    # The context manager closes the HDF5 handle even if serialization
    # fails part-way (e.g. `_get_json_type` raising TypeError).
    with h5py.File(filepath, 'w') as f:
        f.attrs['keras_version'] = str(keras_version).encode('utf8')
        f.attrs['backend'] = K.backend().encode('utf8')
        f.attrs['model_config'] = json.dumps(
            {
                'class_name': model.__class__.__name__,
                'config': model.get_config()
            },
            default=_get_json_type).encode('utf8')

        model_weights_group = f.create_group('model_weights')
        topology.save_weights_to_hdf5_group(model_weights_group, model.layers)

        # Skip the optimizer when the attribute is missing or None; a None
        # optimizer would otherwise fail on `.get_config()` below.
        optimizer = getattr(model, 'optimizer', None)
        if optimizer is not None:
            if isinstance(optimizer, optimizers.TFOptimizer):
                warnings.warn(
                    'TensorFlow optimizers do not '
                    'make it possible to access '
                    'optimizer attributes or optimizer state '
                    'after instantiation. '
                    'As a result, we cannot save the optimizer '
                    'as part of the model save file. '
                    'You will have to compile your model again after loading it. '
                    'Prefer using a Keras optimizer instead '
                    '(see keras.io/optimizers).')
            else:
                f.attrs['training_config'] = json.dumps(
                    {
                        'optimizer_config': {
                            'class_name': optimizer.__class__.__name__,
                            'config': optimizer.get_config()
                        },
                        'loss': model.loss,
                        'metrics': model.metrics,
                        'sample_weight_mode': model.sample_weight_mode,
                        'loss_weights': model.loss_weights,
                    },
                    default=_get_json_type).encode('utf8')

                # Save optimizer weights.
                symbolic_weights = optimizer.weights
                if symbolic_weights:
                    optimizer_weights_group = f.create_group('optimizer_weights')
                    weight_values = K.batch_get_value(symbolic_weights)
                    backend = K.backend()
                    weight_names = []
                    for i, w in enumerate(symbolic_weights):
                        name = 'param_' + str(i)
                        if hasattr(w, 'name'):
                            if backend == 'theano':
                                # Theano's default symbolic-weight name is
                                # '/variable'; keep the positional name then.
                                if w.name != '/variable':
                                    name = str(w.name)
                            elif w.name:
                                name = str(w.name)
                        weight_names.append(name.encode('utf8'))
                    optimizer_weights_group.attrs['weight_names'] = weight_names
                    for name, val in zip(weight_names, weight_values):
                        param_dset = optimizer_weights_group.create_dataset(
                            name, val.shape, dtype=val.dtype)
                        if not val.shape:
                            # scalar
                            param_dset[()] = val
                        else:
                            param_dset[:] = val
        f.flush()
Ejemplo n.º 4
0
def _get_json_type(obj):
  """Serialize any object to a JSON-serializable structure.

  Used as the `default=` hook of `json.dumps` when `save_model` writes the
  model and training configs.

  Arguments:
      obj: the object to serialize

  Returns:
      JSON-serializable structure representing `obj`.

  Raises:
      TypeError: if `obj` cannot be serialized.
  """
  # if obj is a serializable Keras class instance
  # e.g. optimizer, layer
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

  # if obj is any numpy type
  if type(obj).__module__ == np.__name__:
    # `.item()` only accepts single-element arrays, so arrays are converted
    # to (nested) lists instead.
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    return obj.item()

  # Functions (e.g. loss function) and classes. Classes are callable, so
  # this branch also covers python 'type' objects.
  if callable(obj):
    return obj.__name__

  raise TypeError('Not JSON Serializable:', obj)


def save_model(model, filepath, overwrite=True, include_optimizer=True):
  """Save a model to a HDF5 file.

  The saved model contains:
      - the model's configuration (topology)
      - the model's weights
      - the model's optimizer's state (if any)

  Thus the saved model can be reinstantiated in
  the exact same state, without any of the code
  used for model definition or training.

  Arguments:
      model: Keras model instance to be saved.
      filepath: String, path where to save the model.
      overwrite: Whether we should overwrite any existing
          model at the target location, or instead
          ask the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.

  Raises:
      ImportError: if h5py is not available.
      TypeError: if the model or training config contains an object that
          cannot be serialized to JSON.
  """
  if h5py is None:
    raise ImportError('`save_model` requires h5py.')

  from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  # If file exists and should not be overwritten.
  if not overwrite and os.path.isfile(filepath):
    proceed = ask_to_proceed_with_overwrite(filepath)
    if not proceed:
      return

  # The context manager closes the HDF5 handle even if serialization fails
  # part-way (e.g. `_get_json_type` raising TypeError).
  with h5py.File(filepath, 'w') as f:
    f.attrs['keras_version'] = str(keras_version).encode('utf8')
    f.attrs['backend'] = K.backend().encode('utf8')
    f.attrs['model_config'] = json.dumps(
        {
            'class_name': model.__class__.__name__,
            'config': model.get_config()
        },
        default=_get_json_type).encode('utf8')

    model_weights_group = f.create_group('model_weights')
    topology.save_weights_to_hdf5_group(model_weights_group, model.layers)

    # Skip the optimizer when not requested, or when the attribute is
    # missing or None; a None optimizer would otherwise fail on
    # `.get_config()` below.
    optimizer = (getattr(model, 'optimizer', None)
                 if include_optimizer else None)
    if optimizer is not None:
      if isinstance(optimizer, optimizers.TFOptimizer):
        logging.warning(
            'TensorFlow optimizers do not '
            'make it possible to access '
            'optimizer attributes or optimizer state '
            'after instantiation. '
            'As a result, we cannot save the optimizer '
            'as part of the model save file. '
            'You will have to compile your model again after loading it. '
            'Prefer using a Keras optimizer instead '
            '(see keras.io/optimizers).')
      else:
        f.attrs['training_config'] = json.dumps(
            {
                'optimizer_config': {
                    'class_name': optimizer.__class__.__name__,
                    'config': optimizer.get_config()
                },
                'loss': model.loss,
                'metrics': model.metrics,
                'sample_weight_mode': model.sample_weight_mode,
                'loss_weights': model.loss_weights,
            },
            default=_get_json_type).encode('utf8')

        # Save optimizer weights.
        symbolic_weights = optimizer.weights
        if symbolic_weights:
          optimizer_weights_group = f.create_group('optimizer_weights')
          weight_values = K.batch_get_value(symbolic_weights)
          backend = K.backend()
          weight_names = []
          for i, w in enumerate(symbolic_weights):
            name = 'param_' + str(i)
            if hasattr(w, 'name'):
              if backend == 'theano':
                # Theano's default symbolic-weight name is '/variable';
                # keep the positional name in that case.
                if w.name != '/variable':
                  name = str(w.name)
              elif w.name:
                name = str(w.name)
            weight_names.append(name.encode('utf8'))
          optimizer_weights_group.attrs['weight_names'] = weight_names
          for name, val in zip(weight_names, weight_values):
            param_dset = optimizer_weights_group.create_dataset(
                name, val.shape, dtype=val.dtype)
            if not val.shape:
              # scalar
              param_dset[()] = val
            else:
              param_dset[:] = val
    f.flush()