예제 #1
0
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
    """Constructs an `Estimator` instance from given keras model.

    For usage example, please see:
    [Creating estimators from Keras
    Models](https://tensorflow.org/guide/estimators#model_to_estimator).

    Args:
      keras_model: A compiled Keras model object. This argument is mutually
        exclusive with `keras_model_path`.
      keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
        format, which can be generated with the `save()` method of a Keras
        model. This argument is mutually exclusive with `keras_model`.
      custom_objects: Dictionary for custom objects. Also used when loading
        the model from `keras_model_path`.
      model_dir: Directory to save `Estimator` model parameters, graph, summary
        files for TensorBoard, etc.
      config: `RunConfig` to config `Estimator`.

    Returns:
      An Estimator from given keras model.

    Raises:
      ValueError: if neither keras_model nor keras_model_path was given.
      ValueError: if both keras_model and keras_model_path were given.
      ValueError: if the keras_model_path is a GCS URI.
      ValueError: if keras_model has not been compiled.
    """
    # Test presence with `is None` rather than truthiness: an arbitrary model
    # object is not guaranteed to be truthy (e.g. if its type defines
    # `__len__`), which would misreport a provided model as missing.
    if keras_model is None and not keras_model_path:
        raise ValueError(
            'Either `keras_model` or `keras_model_path` needs to be provided.')
    if keras_model is not None and keras_model_path:
        raise ValueError(
            'Please specify either `keras_model` or `keras_model_path`, '
            'but not both.')

    if keras_model is None:
        # Only local paths are supported; GCS locations are rejected up front.
        if (keras_model_path.startswith('gs://') or
                'storage.googleapis.com' in keras_model_path):
            raise ValueError(
                '%s is not a local path. Please copy the model locally first.'
                % keras_model_path)
        logging.info('Loading models from %s', keras_model_path)
        # Forward custom_objects so models containing custom layers/losses
        # can be deserialized from disk.
        keras_model = models.load_model(keras_model_path,
                                        custom_objects=custom_objects)
    else:
        logging.info('Using the Keras model provided.')

    # A model without an optimizer has not been compiled.
    if not getattr(keras_model, 'optimizer', None):
        raise ValueError('The given keras model has not been compiled yet. '
                         'Please compile the model with `model.compile()` '
                         'before calling `model_to_estimator()`.')

    config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
        config, model_dir)

    keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
    if _any_weight_initialized(keras_model):
        # Warn if config passed to estimator tries to update GPUOptions. If a
        # session has already been created, the GPUOptions passed to the first
        # session sticks.
        if config.session_config.HasField('gpu_options'):
            logging.warning(
                'The Keras backend session has already been set. '
                'The _session_config passed to model_to_estimator will not be used.'
            )
    else:
        # Pass the config into keras backend's default session.
        sess = session.Session(config=config.session_config)
        K.set_session(sess)

    # Only graph-network models get a warm-start checkpoint. For models
    # manually subclassed from `Model`, existing weights are not preserved
    # (see the warning below).
    warm_start_path = None
    if keras_model._is_graph_network:
        warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
                                                 config)
    elif keras_model.built:
        logging.warning(
            'You are creating an Estimator from a Keras model manually '
            'subclassed from `Model`, that was already called on some '
            'inputs (and thus already had weights). We are currently '
            'unable to preserve the model\'s state (its weights) as '
            'part of the estimator in this case. Be warned that the '
            'estimator has been created using a freshly initialized '
            'version of your model.\n'
            'Note that this doesn\'t affect the state of the model '
            'instance you passed as `keras_model` argument.')

    estimator = estimator_lib.Estimator(keras_model_fn,
                                        config=config,
                                        warm_start_from=warm_start_path)

    return estimator
예제 #2
0
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  """Constructs an `Estimator` instance from given keras model.

  For usage example, please see:
  [Creating estimators from Keras
  Models](https://tensorflow.org/guide/estimators#model_to_estimator).

  Args:
    keras_model: A compiled Keras model object. This argument is mutually
      exclusive with `keras_model_path`.
    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
      format, which can be generated with the `save()` method of a Keras model.
      This argument is mutually exclusive with `keras_model`.
    custom_objects: Dictionary for custom objects. Also used when loading the
      model from `keras_model_path`.
    model_dir: Directory to save `Estimator` model parameters, graph, summary
      files for TensorBoard, etc.
    config: `RunConfig` to config `Estimator`.

  Returns:
    An Estimator from given keras model.

  Raises:
    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path were given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  # Test presence with `is None` rather than truthiness: an arbitrary model
  # object is not guaranteed to be truthy (e.g. if its type defines
  # `__len__`), which would misreport a provided model as missing.
  if keras_model is None and not keras_model_path:
    raise ValueError(
        'Either `keras_model` or `keras_model_path` needs to be provided.')
  if keras_model is not None and keras_model_path:
    raise ValueError(
        'Please specify either `keras_model` or `keras_model_path`, '
        'but not both.')

  if keras_model is None:
    # Only local paths are supported; GCS locations are rejected up front.
    if (keras_model_path.startswith('gs://') or
        'storage.googleapis.com' in keras_model_path):
      raise ValueError(
          '%s is not a local path. Please copy the model locally first.' %
          keras_model_path)
    logging.info('Loading models from %s', keras_model_path)
    # Forward custom_objects so models containing custom layers/losses can be
    # deserialized from disk.
    keras_model = models.load_model(keras_model_path,
                                    custom_objects=custom_objects)
  else:
    logging.info('Using the Keras model provided.')

  # A model without an optimizer has not been compiled.
  if not getattr(keras_model, 'optimizer', None):
    raise ValueError(
        'The given keras model has not been compiled yet. '
        'Please compile the model with `model.compile()` '
        'before calling `model_to_estimator()`.')

  config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
                                                                      model_dir)

  keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
  if _any_weight_initialized(keras_model):
    # Warn if config passed to estimator tries to update GPUOptions. If a
    # session has already been created, the GPUOptions passed to the first
    # session sticks.
    if config.session_config.HasField('gpu_options'):
      logging.warning(
          'The Keras backend session has already been set. '
          'The _session_config passed to model_to_estimator will not be used.')
  else:
    # Pass the config into keras backend's default session.
    sess = session.Session(config=config.session_config)
    K.set_session(sess)

  # Only graph-network models get a warm-start checkpoint. For models manually
  # subclassed from `Model`, existing weights are not preserved (see the
  # warning below).
  warm_start_path = None
  if keras_model._is_graph_network:
    warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
                                             config)
  elif keras_model.built:
    logging.warning('You are creating an Estimator from a Keras model manually '
                    'subclassed from `Model`, that was already called on some '
                    'inputs (and thus already had weights). We are currently '
                    'unable to preserve the model\'s state (its weights) as '
                    'part of the estimator in this case. Be warned that the '
                    'estimator has been created using a freshly initialized '
                    'version of your model.\n'
                    'Note that this doesn\'t affect the state of the model '
                    'instance you passed as `keras_model` argument.')

  estimator = estimator_lib.Estimator(keras_model_fn,
                                      config=config,
                                      warm_start_from=warm_start_path)

  return estimator