Пример #1
0
def setup_tpu_session(master):
  """Creates a TF session targeting the TPU `master` and registers it with Keras.

  The session is built with `isolate_session_state=True`, installed as the
  Keras backend session, and then used (via the Keras backend) to run the TPU
  system initialization op.

  Args:
    master: target passed to the TF session (the TPU master address).

  Returns:
    The newly created session.
  """
  isolated_config = config_pb2.ConfigProto(isolate_session_state=True)
  tpu_sess = tf_session.Session(target=master, config=isolated_config)
  K.set_session(tpu_sess)
  # Fetch the backend session before building the init op, preserving the
  # original call order.
  backend_sess = K.get_session()
  backend_sess.run(tpu.initialize_system())
  return tpu_sess
Пример #2
0
def train():
    """Builds the Keras model, wraps it in an Estimator, then trains and evaluates it.

    Depends on names defined elsewhere in the module: `model_name`,
    `train_path`, `valid_path`, `n_epochs`, `batch_size`, `n_steps`, the
    `neuralnetwork()` model factory and the `the_input_iterator()` input fn.
    Checkpoints are written under `<cwd>/models/<model_name>`.
    """
    # Register a fresh TF session as the Keras backend session before the
    # model is built.
    sess = tf.Session()
    K.set_session(sess)

    model = neuralnetwork()
    # Build the path from components instead of formatting a "./"-prefixed
    # string, so the resulting directory is clean (no "/./" segment).
    model_dir = os.path.join(os.getcwd(), "models", str(model_name))
    os.makedirs(model_dir, exist_ok=True)
    print("model_dir: ", model_dir)
    the_estimator = tf.keras.estimator.model_to_estimator(keras_model=model,
                                                          model_dir=model_dir)

    # Training input: shuffled, repeated for `n_epochs`, capped at `n_steps`.
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: the_input_iterator(train_path,
                                            perform_shuffle=True,
                                            repeat_count=n_epochs,
                                            batch_size=batch_size),
        max_steps=n_steps)
    # Validation input: unshuffled, one example per batch.
    valid_spec = tf.estimator.EvalSpec(
        input_fn=lambda: the_input_iterator(valid_path,
                                            perform_shuffle=False,
                                            batch_size=1))

    tf.estimator.train_and_evaluate(the_estimator, train_spec, valid_spec)
Пример #3
0
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
    """Constructs an `Estimator` instance from given keras model.

    For usage example, please see
    @{$programmers_guide/estimators$creating_estimators_from_keras_models}.

    Args:
      keras_model: Keras model in memory.
      keras_model_path: Directory to a keras model on disk.
      custom_objects: Dictionary for custom objects.
      model_dir: Directory to save Estimator model parameters, graph and etc.
      config: Configuration object, or a dict of keyword arguments used to
        build a `RunConfig`.

    Returns:
      An Estimator from given keras model.

    Raises:
      ValueError: if neither keras_model nor keras_model_path was given.
      ValueError: if both keras_model and keras_model_path were given.
      ValueError: if the keras_model_path is a GCS URI.
      ValueError: if keras_model has not been compiled.
    """
    if (not keras_model) and (not keras_model_path):
        raise ValueError(
            'Either `keras_model` or `keras_model_path` needs to be provided.')
    if keras_model and keras_model_path:
        raise ValueError(
            'Please specify either `keras_model` or `keras_model_path`, '
            'but not both.')

    if not keras_model:
        # Only local paths are accepted; GCS locations are rejected up front.
        if keras_model_path.startswith(
                'gs://') or 'storage.googleapis.com' in keras_model_path:
            raise ValueError(
                '%s is not a local path. Please copy the model locally first.'
                % keras_model_path)
        logging.info('Loading models from %s', keras_model_path)
        keras_model = models.load_model(keras_model_path)
    else:
        logging.info('Using the Keras model provided.')

    # A compiled model is required; compilation is what sets `optimizer`.
    if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
        raise ValueError(
            'The given keras model has not been compiled yet. Please compile first '
            'before calling `model_to_estimator`.')

    if isinstance(config, dict):
        config = run_config_lib.RunConfig(**config)

    keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
    estimator = estimator_lib.Estimator(keras_model_fn,
                                        model_dir=model_dir,
                                        config=config)

    # Pass the config into keras backend's default session. The session must
    # stay open: it becomes the Keras backend session, which
    # `keras_model.get_weights()` below reads values through. Opening it in a
    # `with` block would close it on exit and leave Keras holding a closed
    # session.
    sess = session.Session(config=estimator._session_config)
    K.set_session(sess)
    # NOTE(review): variable values held by any previously set Keras session
    # are not carried into this new session — confirm callers do not rely on
    # weights trained before conversion.

    keras_weights = keras_model.get_weights()
    if keras_model._is_graph_network:
        # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
        _save_first_checkpoint(keras_model, estimator, custom_objects,
                               keras_weights)
    elif keras_model.built:
        # Subclassed models that were already called cannot have their
        # weights carried into the estimator; warn the caller.
        logging.warning('You are creating an Estimator from a Keras model '
                        'manually subclassed from `Model`, that was '
                        'already called on some inputs (and thus already had '
                        'weights). We are currently unable to preserve '
                        'the model\'s state (its weights) '
                        'as part of the estimator '
                        'in this case. Be warned that the estimator '
                        'has been created using '
                        'a freshly initialized version of your model.\n'
                        'Note that this doesn\'t affect the state of the '
                        'model instance you passed as `keras_model` argument.')
    return estimator
Пример #4
0
def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  """Constructs an `Estimator` instance from given keras model.

  For usage example, please see
  @{$programmers_guide/estimators$creating_estimators_from_keras_models}.

  Args:
    keras_model: Keras model in memory.
    keras_model_path: Directory to a keras model on disk.
    custom_objects: Dictionary for custom objects.
    model_dir: Directory to save Estimator model parameters, graph and etc.
    config: Configuration object, or a dict of keyword arguments used to
      build a `RunConfig`.

  Returns:
    An Estimator from given keras model.

  Raises:
    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path were given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  if (not keras_model) and (not keras_model_path):
    raise ValueError(
        'Either `keras_model` or `keras_model_path` needs to be provided.')
  if keras_model and keras_model_path:
    raise ValueError(
        'Please specify either `keras_model` or `keras_model_path`, '
        'but not both.')

  if not keras_model:
    # Only local paths are accepted; GCS locations are rejected up front.
    if keras_model_path.startswith(
        'gs://') or 'storage.googleapis.com' in keras_model_path:
      raise ValueError(
          '%s is not a local path. Please copy the model locally first.' %
          keras_model_path)
    logging.info('Loading models from %s', keras_model_path)
    keras_model = models.load_model(keras_model_path)
  else:
    logging.info('Using the Keras model provided.')

  # A compiled model is required; compilation is what sets `optimizer`.
  if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
    raise ValueError(
        'The given keras model has not been compiled yet. Please compile first '
        'before calling `model_to_estimator`.')

  if isinstance(config, dict):
    config = run_config_lib.RunConfig(**config)

  keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
  estimator = estimator_lib.Estimator(
      keras_model_fn, model_dir=model_dir, config=config)

  # Check if we need to call get_weights:
  if _any_variable_initalized():
    # Some variables already hold values in the existing backend session, so
    # keep that session and capture the current weights from it.
    keras_weights = keras_model.get_weights()
    # Warn if config passed to estimator tries to update GPUOptions. If a
    # session has already been created, the GPUOptions passed to the first
    # session sticks.
    if estimator._session_config.HasField('gpu_options'):
      logging.warning(
          'The Keras backend session has already been set. '
          'The _session_config passed to model_to_estimator will not be used.')
  else:
    # Pass the config into keras backend's default session. The session is
    # intentionally left open because Keras keeps using it as its backend
    # session.
    sess = session.Session(config=estimator._session_config)
    K.set_session(sess)
    keras_weights = None

  if keras_model._is_graph_network:
    # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
    _save_first_checkpoint(keras_model,
                           estimator,
                           custom_objects,
                           keras_weights)
  elif keras_model.built:
    # Subclassed models that were already called cannot have their weights
    # carried into the estimator; warn the caller.
    logging.warning('You are creating an Estimator from a Keras model '
                    'manually subclassed from `Model`, that was '
                    'already called on some inputs (and thus already had '
                    'weights). We are currently unable to preserve '
                    'the model\'s state (its weights) '
                    'as part of the estimator '
                    'in this case. Be warned that the estimator '
                    'has been created using '
                    'a freshly initialized version of your model.\n'
                    'Note that this doesn\'t affect the state of the '
                    'model instance you passed as `keras_model` argument.')
  return estimator