Code example #1
  def _python_properties_internal(self):
    metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
    # Network stateful property is dependent on the child layers.
    metadata.pop('stateful')
    metadata['is_graph_network'] = self.obj._is_graph_network  # pylint: disable=protected-access
    # Record the input signature captured when the model was first called/built.
    metadata['save_spec'] = self.obj._get_save_spec(dynamic_batch=False)  # pylint: disable=protected-access

    metadata.update(
        saving_utils.model_metadata(
            self.obj, include_optimizer=True, require_config=False))
    return metadata
Code example #2
    def _python_properties_internal(self):
        metadata = super()._python_properties_internal()
        # Network stateful property is dependent on the child layers.
        metadata.pop('stateful')
        metadata['is_graph_network'] = self.obj._is_graph_network  # pylint: disable=protected-access
        spec = self.obj.save_spec(dynamic_batch=False)
        metadata['full_save_spec'] = spec
        # save_spec is saved for forward compatibility on older TF versions.
        metadata['save_spec'] = None if spec is None else spec[0][0]

        metadata.update(
            saving_utils.model_metadata(self.obj,
                                        include_optimizer=True,
                                        require_config=False))
        return metadata
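The save spec recorded above is also reachable through the public API: in recent TF versions (roughly 2.6 and later), a built tf.keras.Model exposes save_spec(), which returns an (args, kwargs) pair of tf.TensorSpec structures; the metadata above stores that pair under 'full_save_spec' and its first positional entry under 'save_spec'. A minimal sketch, assuming such a TF version (the layer sizes and dummy input are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model(tf.zeros((2, 8)))  # calling the model once records its save spec

args_spec, kwargs_spec = model.save_spec(dynamic_batch=False)
print(args_spec[0])  # TensorSpec of the first positional input, batch dim kept as traced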
Code example #3
File: hdf5_format.py  Project: qlzh727/keras
def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
  """Saves a model to a HDF5 file.

  The saved model contains:
      - the model's configuration (topology)
      - the model's weights
      - the model's optimizer's state (if any)

  Thus the saved model can be reinstantiated in
  the exact same state, without any of the code
  used for model definition or training.

  Args:
      model: Keras model instance to be saved.
      filepath: One of the following:
          - String, path where to save the model
          - `h5py.File` object where to save the model
      overwrite: Whether we should overwrite any existing
          model at the target location, or instead
          ask the user with a manual prompt.
      include_optimizer: If True, save the optimizer's state together with the
          model.

  Raises:
      ImportError: if h5py is not available.
  """

  if h5py is None:
    raise ImportError('`save_model()` using h5 format requires h5py. Could not '
                      'import h5py.')

  # TODO(psv) Add warning when we save models that contain non-serializable
  # entities like metrics added using `add_metric` and losses added using
  # `add_loss`.
  if len(model.weights) != len(model._undeduplicated_weights):
    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
                    'This is usually caused by `Variable`s being shared by '
                    'Layers in the Model. These `Variable`s will be treated '
                    'as separate `Variable`s when the Model is restored. To '
                    'avoid this, please save with `save_format="tf"`.')

  if not isinstance(filepath, h5py.File):
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
      proceed = ask_to_proceed_with_overwrite(filepath)
      if not proceed:
        return

    # Try creating the directory if it does not exist.
    dirpath = os.path.dirname(filepath)
    if not os.path.exists(dirpath):
      tf.io.gfile.makedirs(dirpath)

    f = h5py.File(filepath, mode='w')
    opened_new_file = True
  else:
    f = filepath
    opened_new_file = False

  try:
    model_metadata = saving_utils.model_metadata(model, include_optimizer)
    for k, v in model_metadata.items():
      if isinstance(v, (dict, list, tuple)):
        f.attrs[k] = json.dumps(
            v, default=json_utils.get_json_type).encode('utf8')
      else:
        f.attrs[k] = v

    model_weights_group = f.create_group('model_weights')
    save_weights_to_hdf5_group(model_weights_group, model)

    # TODO(b/128683857): Add integration tests between tf.keras and external
    # Keras, to avoid breaking TF.js users.
    if (include_optimizer and model.optimizer and
        not isinstance(model.optimizer, optimizer_v1.TFOptimizer)):
      save_optimizer_weights_to_hdf5_group(f, model.optimizer)

    f.flush()
  finally:
    if opened_new_file:
      f.close()
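For reference, a minimal usage sketch of the HDF5 path: the public model.save API dispatches to this HDF5 saver when the target path ends in '.h5' (the model architecture and file name below are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

model.save('model.h5')  # writes topology, weights, and optimizer state to one HDF5 file
restored = tf.keras.models.load_model('model.h5')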