def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  """Constructs an `Estimator` instance from given keras model.

  For usage example, please see:
  [Creating estimators from Keras Models](
    https://tensorflow.org/guide/estimators#model_to_estimator).

  Args:
    keras_model: A compiled Keras model object. This argument is mutually
      exclusive with `keras_model_path`.
    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
      format, which can be generated with the `save()` method of a Keras model.
      This argument is mutually exclusive with `keras_model`.
    custom_objects: Dictionary for custom objects. Forwarded to model loading
      (when `keras_model_path` is used), to model-function creation, and to
      the initial checkpoint save.
    model_dir: Directory to save `Estimator` model parameters, graph, summary
      files for TensorBoard, etc.
    config: `RunConfig` to config `Estimator`.

  Returns:
    An Estimator from given keras model.

  Raises:
    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path were given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  if not (keras_model or keras_model_path):
    raise ValueError(
        'Either `keras_model` or `keras_model_path` needs to be provided.')
  if keras_model and keras_model_path:
    raise ValueError(
        'Please specify either `keras_model` or `keras_model_path`, '
        'but not both.')

  if keras_model:
    logging.info('Using the Keras model provided.')
  else:
    # Only local files are supported; reject GCS locations before loading.
    if (keras_model_path.startswith('gs://') or
        'storage.googleapis.com' in keras_model_path):
      raise ValueError(
          '%s is not a local path. Please copy the model locally first.' %
          keras_model_path)
    logging.info('Loading models from %s', keras_model_path)
    # Forward custom_objects so models containing custom layers/losses can be
    # deserialized; None keeps load_model's default behaviour.
    keras_model = models.load_model(
        keras_model_path, custom_objects=custom_objects)

  # A model without an optimizer (or without the attribute at all) has not
  # been compiled.
  if not getattr(keras_model, 'optimizer', None):
    raise ValueError(
        'The given keras model has not been compiled yet. '
        'Please compile the model with `model.compile()` '
        'before calling `model_to_estimator()`.')

  # NOTE(review): helper defined elsewhere; presumably applies model_dir and a
  # default session_config to `config` — confirm.
  config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
      config, model_dir)

  keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)

  if _any_weight_initialized(keras_model):
    # Warn if config passed to estimator tries to update GPUOptions. If a
    # session has already been created, the GPUOptions passed to the first
    # session sticks.
    if config.session_config.HasField('gpu_options'):
      logging.warning(
          'The Keras backend session has already been set. '
          'The _session_config passed to model_to_estimator will not be used.'
      )
  else:
    # Pass the config into keras backend's default session.
    sess = session.Session(config=config.session_config)
    K.set_session(sess)

  warm_start_path = None
  if keras_model._is_graph_network:  # pylint: disable=protected-access
    # The checkpoint written here is used to warm-start the Estimator.
    warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
                                             config)
  elif keras_model.built:
    # Subclassed models that already hold weights cannot be checkpointed here,
    # so the Estimator starts from fresh weights.
    logging.warning(
        'You are creating an Estimator from a Keras model manually '
        'subclassed from `Model`, that was already called on some '
        'inputs (and thus already had weights). We are currently '
        'unable to preserve the model\'s state (its weights) as '
        'part of the estimator in this case. Be warned that the '
        'estimator has been created using a freshly initialized '
        'version of your model.\n'
        'Note that this doesn\'t affect the state of the model '
        'instance you passed as `keras_model` argument.')

  estimator = estimator_lib.Estimator(
      keras_model_fn, config=config, warm_start_from=warm_start_path)
  return estimator
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  """Constructs an `Estimator` instance from given keras model.

  For usage example, please see:
  [Creating estimators from Keras Models](
    https://tensorflow.org/guide/estimators#model_to_estimator).

  Args:
    keras_model: A compiled Keras model object. This argument is mutually
      exclusive with `keras_model_path`.
    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
      format, which can be generated with the `save()` method of a Keras model.
      This argument is mutually exclusive with `keras_model`.
    custom_objects: Dictionary for custom objects. Forwarded to model loading
      (when `keras_model_path` is used), to model-function creation, and to
      the initial checkpoint save.
    model_dir: Directory to save `Estimator` model parameters, graph, summary
      files for TensorBoard, etc.
    config: `RunConfig` to config `Estimator`.

  Returns:
    An Estimator from given keras model.

  Raises:
    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path were given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  if not (keras_model or keras_model_path):
    raise ValueError(
        'Either `keras_model` or `keras_model_path` needs to be provided.')
  if keras_model and keras_model_path:
    raise ValueError(
        'Please specify either `keras_model` or `keras_model_path`, '
        'but not both.')

  if keras_model:
    logging.info('Using the Keras model provided.')
  else:
    # Only local files are supported; reject GCS locations before loading.
    if (keras_model_path.startswith('gs://') or
        'storage.googleapis.com' in keras_model_path):
      raise ValueError(
          '%s is not a local path. Please copy the model locally first.' %
          keras_model_path)
    logging.info('Loading models from %s', keras_model_path)
    # Forward custom_objects so models containing custom layers/losses can be
    # deserialized; None keeps load_model's default behaviour.
    keras_model = models.load_model(
        keras_model_path, custom_objects=custom_objects)

  # A model without an optimizer (or without the attribute at all) has not
  # been compiled.
  if not getattr(keras_model, 'optimizer', None):
    raise ValueError(
        'The given keras model has not been compiled yet. '
        'Please compile the model with `model.compile()` '
        'before calling `model_to_estimator()`.')

  # NOTE(review): helper defined elsewhere; presumably applies model_dir and a
  # default session_config to `config` — confirm.
  config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
      config, model_dir)

  keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)

  if _any_weight_initialized(keras_model):
    # Warn if config passed to estimator tries to update GPUOptions. If a
    # session has already been created, the GPUOptions passed to the first
    # session sticks.
    if config.session_config.HasField('gpu_options'):
      logging.warning(
          'The Keras backend session has already been set. '
          'The _session_config passed to model_to_estimator will not be used.')
  else:
    # Pass the config into keras backend's default session.
    sess = session.Session(config=config.session_config)
    K.set_session(sess)

  warm_start_path = None
  if keras_model._is_graph_network:  # pylint: disable=protected-access
    # The checkpoint written here is used to warm-start the Estimator.
    warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
                                             config)
  elif keras_model.built:
    # Subclassed models that already hold weights cannot be checkpointed here,
    # so the Estimator starts from fresh weights.
    logging.warning('You are creating an Estimator from a Keras model manually '
                    'subclassed from `Model`, that was already called on some '
                    'inputs (and thus already had weights). We are currently '
                    'unable to preserve the model\'s state (its weights) as '
                    'part of the estimator in this case. Be warned that the '
                    'estimator has been created using a freshly initialized '
                    'version of your model.\n'
                    'Note that this doesn\'t affect the state of the model '
                    'instance you passed as `keras_model` argument.')

  estimator = estimator_lib.Estimator(
      keras_model_fn, config=config, warm_start_from=warm_start_path)
  return estimator