Ejemplo n.º 1
0
    def test_with_empty_config(self):
        keras_model, _, _, _, _ = get_resource_for_simple_model(
            model_type='sequential', is_evaluate=True)
        keras_model.compile(
            loss='categorical_crossentropy',
            optimizer='rmsprop',
            metrics=['mse', keras.metrics.categorical_accuracy])

        with self.test_session():
            est_keras = keras_lib.model_to_estimator(
                keras_model=keras_model,
                model_dir=self._base_dir,
                config=run_config_lib.RunConfig())
            self.assertEqual(run_config_lib.get_default_session_config(),
                             est_keras._session_config)
            self.assertEqual(est_keras._session_config,
                             est_keras._config.session_config)
            self.assertEqual(self._base_dir, est_keras._config.model_dir)
            self.assertEqual(self._base_dir, est_keras._model_dir)

        with self.test_session():
            est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     model_dir=self._base_dir,
                                                     config=None)
            self.assertEqual(run_config_lib.get_default_session_config(),
                             est_keras._session_config)
            self.assertEqual(est_keras._session_config,
                             est_keras._config.session_config)
            self.assertEqual(self._base_dir, est_keras._config.model_dir)
            self.assertEqual(self._base_dir, est_keras._model_dir)
Ejemplo n.º 2
0
  def test_with_empty_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.categorical_accuracy])

    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, model_dir=self._base_dir,
          config=run_config_lib.RunConfig())
      self.assertEqual(run_config_lib.get_default_session_config(),
                       est_keras._session_config)
      self.assertEqual(est_keras._session_config,
                       est_keras._config.session_config)
      self.assertEqual(self._base_dir, est_keras._config.model_dir)
      self.assertEqual(self._base_dir, est_keras._model_dir)

    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, model_dir=self._base_dir,
          config=None)
      self.assertEqual(run_config_lib.get_default_session_config(),
                       est_keras._session_config)
      self.assertEqual(est_keras._session_config,
                       est_keras._config.session_config)
      self.assertEqual(self._base_dir, est_keras._config.model_dir)
      self.assertEqual(self._base_dir, est_keras._model_dir)
Ejemplo n.º 3
0
def _maybe_overwrite_model_dir_and_session_config(config, model_dir):
  """Overwrite estimator config by `model_dir` and `session_config` if needed.

  Args:
    config: Original estimator config.
    model_dir: Estimator model checkpoint directory.

  Returns:
    Overwritten estimator config.

  Raises:
    ValueError: Model directory inconsistent between `model_dir` and `config`.
  """

  default_session_config = run_config_lib.get_default_session_config()
  if isinstance(config, dict):
    config = RunConfig(**config)
  elif config is None:
    config = RunConfig(session_config=default_session_config)
  if config.session_config is None:
    config = RunConfig.replace(config, session_config=default_session_config)

  if model_dir is not None:
    if (getattr(config, 'model_dir', None) is not None and
        config.model_dir != model_dir):
      raise ValueError(
          "`model_dir` are set both in constructor and `RunConfig`, but with "
          "different values. In constructor: '{}', in `RunConfig`: "
          "'{}' ".format(model_dir, config.model_dir))
    config = RunConfig.replace(config, model_dir=model_dir)
  elif getattr(config, 'model_dir', None) is None:
    model_dir = tempfile.mkdtemp()
    config = RunConfig.replace(config, model_dir=model_dir)

  return config