Example #1
  def test_with_empty_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.categorical_accuracy])

    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, model_dir=self._base_dir,
          config=run_config_lib.RunConfig())
      self.assertEqual(run_config_lib.get_default_session_config(),
                       est_keras._session_config)
      self.assertEqual(est_keras._session_config,
                       est_keras._config.session_config)
      self.assertEqual(self._base_dir, est_keras._config.model_dir)
      self.assertEqual(self._base_dir, est_keras._model_dir)

    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, model_dir=self._base_dir,
          config=None)
      self.assertEqual(run_config_lib.get_default_session_config(),
                       est_keras._session_config)
      self.assertEqual(est_keras._session_config,
                       est_keras._config.session_config)
      self.assertEqual(self._base_dir, est_keras._config.model_dir)
      self.assertEqual(self._base_dir, est_keras._model_dir)
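For readers who want the same conversion outside of the test harness, the following is a minimal sketch of the pattern these examples exercise. It assumes a TF 1.x-era installation and uses the public names tf.keras.estimator.model_to_estimator and tf.estimator.inputs.numpy_input_fn in place of the internal keras_lib and numpy_io modules; the model, the model_dir path, and the random data are placeholders.

import numpy as np
import tensorflow as tf

# Build and compile a small Keras model, similar to the test models above.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['categorical_accuracy'])

# Public counterpart of the internal keras_lib.model_to_estimator call.
est = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir='/tmp/keras_estimator')  # placeholder dir

# numpy_input_fn feeds numpy arrays keyed by the model's input tensor name,
# which is the same contract the tests rely on via input_names[0].
x = np.random.random((64, 10)).astype(np.float32)
y = tf.keras.utils.to_categorical(
    np.random.randint(2, size=64), 2).astype(np.float32)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={model.input_names[0]: x}, y=y,
    shuffle=True, num_epochs=None, batch_size=16)

est.train(input_fn=train_input_fn, steps=10)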
Example #2
    def test_custom_objects(self):
        def relu6(x):
            return keras.backend.relu(x, max_value=6)

        keras_model = simple_functional_model(activation=relu6)
        keras_model.compile(loss='categorical_crossentropy', optimizer='adam')
        custom_objects = {'relu6': relu6}

        (x_train,
         y_train), _ = testing_utils.get_test_data(train_samples=_TRAIN_SIZE,
                                                   test_samples=50,
                                                   input_shape=(10, ),
                                                   num_classes=2)
        y_train = keras.utils.to_categorical(y_train, 2)
        input_name = keras_model.input_names[0]
        output_name = keras_model.output_names[0]
        train_input_fn = numpy_io.numpy_input_fn(
            x=randomize_io_type(x_train, input_name),
            y=randomize_io_type(y_train, output_name),
            shuffle=False,
            num_epochs=None,
            batch_size=16)
        with self.assertRaisesRegexp(ValueError, 'relu6'):
            with self.test_session():
                est = keras_lib.model_to_estimator(
                    keras_model=keras_model,
                    model_dir=tempfile.mkdtemp(dir=self._base_dir))
                est.train(input_fn=train_input_fn, steps=1)

        with self.test_session():
            est = keras_lib.model_to_estimator(
                keras_model=keras_model,
                model_dir=tempfile.mkdtemp(dir=self._base_dir),
                custom_objects=custom_objects)
            est.train(input_fn=train_input_fn, steps=1)
Example #3
    def test_with_empty_config(self):
        keras_model, _, _, _, _ = get_resource_for_simple_model(
            model_type='sequential', is_evaluate=True)
        keras_model.compile(
            loss='categorical_crossentropy',
            optimizer='rmsprop',
            metrics=['mse', keras.metrics.categorical_accuracy])

        with self.test_session():
            est_keras = keras_lib.model_to_estimator(
                keras_model=keras_model,
                model_dir=self._base_dir,
                config=run_config_lib.RunConfig())
            self.assertEqual(run_config_lib.get_default_session_config(),
                             est_keras._session_config)
            self.assertEqual(est_keras._session_config,
                             est_keras._config.session_config)
            self.assertEqual(self._base_dir, est_keras._config.model_dir)
            self.assertEqual(self._base_dir, est_keras._model_dir)

        with self.test_session():
            est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     model_dir=self._base_dir,
                                                     config=None)
            self.assertEqual(run_config_lib.get_default_session_config(),
                             est_keras._session_config)
            self.assertEqual(est_keras._session_config,
                             est_keras._config.session_config)
            self.assertEqual(self._base_dir, est_keras._config.model_dir)
            self.assertEqual(self._base_dir, est_keras._model_dir)
Example #4
  def test_with_conflicting_model_dir_and_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.categorical_accuracy])

    with self.test_session():
      with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
                                   'constructor and `RunConfig`'):
        keras_lib.model_to_estimator(
            keras_model=keras_model, model_dir=self._base_dir,
            config=run_config_lib.RunConfig(model_dir=_TMP_DIR))
Example #5
  def test_with_conflicting_model_dir_and_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with self.cached_session():
      with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
                                   'constructor and `RunConfig`'):
        keras_lib.model_to_estimator(
            keras_model=keras_model, model_dir=self._base_dir,
            config=run_config_lib.RunConfig(model_dir=_TMP_DIR))
Example #6
    def test_train_functional_with_distribution_strategy(self):
        dist = mirrored_strategy.MirroredStrategy(
            devices=['/device:GPU:0', '/device:GPU:1'])
        keras_model = simple_functional_model()
        keras_model.compile(
            loss='categorical_crossentropy',
            optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
        config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                          model_dir=self._base_dir,
                                          train_distribute=dist,
                                          eval_distribute=dist)
        with self.cached_session():
            est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     config=config)
            before_eval_results = est_keras.evaluate(
                input_fn=get_ds_test_input_fn, steps=1)
            est_keras.train(input_fn=get_ds_train_input_fn,
                            steps=_TRAIN_SIZE / 16)
            after_eval_results = est_keras.evaluate(
                input_fn=get_ds_test_input_fn, steps=1)
            self.assertLess(after_eval_results['loss'],
                            before_eval_results['loss'])

        writer_cache.FileWriterCache.clear()
        gfile.DeleteRecursively(self._config.model_dir)
Example #7
    def test_train_with_tf_optimizer(self):
        for model_type in ['sequential', 'functional']:
            keras_model, (_, _), (
                _, _
            ), train_input_fn, eval_input_fn = get_resource_for_simple_model(
                model_type=model_type, is_evaluate=True)
            keras_model.compile(
                loss='categorical_crossentropy',
                optimizer=rmsprop.RMSPropOptimizer(1e-3),
                metrics=['mse', keras.metrics.categorical_accuracy])

            with self.test_session():
                est_keras = keras_lib.model_to_estimator(
                    keras_model=keras_model,
                    # Also use dict config argument to get test coverage for that line.
                    config={
                        'tf_random_seed': _RANDOM_SEED,
                        'model_dir': self._base_dir,
                    })
                before_eval_results = est_keras.evaluate(
                    input_fn=eval_input_fn, steps=1)
                est_keras.train(input_fn=train_input_fn,
                                steps=_TRAIN_SIZE / 16)
                after_eval_results = est_keras.evaluate(input_fn=eval_input_fn,
                                                        steps=1)
                self.assertLess(after_eval_results['loss'],
                                before_eval_results['loss'])

            writer_cache.FileWriterCache.clear()
            gfile.DeleteRecursively(self._config.model_dir)
Example #8
    def test_train_with_subclassed_model_with_existing_state(self):
        keras_model, (_, _), (
            _,
            _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
                model_type='subclass', is_evaluate=True)
        keras_model.compile(
            loss='categorical_crossentropy',
            optimizer=rmsprop.RMSPropOptimizer(1e-3),
            metrics=['mse', keras.metrics.categorical_accuracy])

        with self.test_session():
            # Create state
            keras_model.train_on_batch(np.random.random((10, ) + _INPUT_SIZE),
                                       np.random.random((10, _NUM_CLASS)))
            original_preds = keras_model.predict(np.ones((10, ) + _INPUT_SIZE))

            est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     config=self._config)
            est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
            before_eval_results = est_keras.evaluate(input_fn=eval_input_fn,
                                                     steps=1)
            est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
            after_eval_results = est_keras.evaluate(input_fn=eval_input_fn,
                                                    steps=1)
            self.assertLess(after_eval_results['loss'],
                            before_eval_results['loss'])

            # Check that original model state was not altered
            preds = keras_model.predict(np.ones((10, ) + _INPUT_SIZE))
            self.assertAllClose(original_preds, preds, atol=1e-5)
            # Check that the original model compilation did not break
            keras_model.train_on_batch(np.random.random((10, ) + _INPUT_SIZE),
                                       np.random.random((10, _NUM_CLASS)))
Example #9
  def test_train_with_subclassed_model_with_existing_state(self):
    keras_model, (_, _), (
        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
            model_type='subclass', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(1e-3),
        metrics=['mse', keras.metrics.categorical_accuracy])

    with self.test_session():
      # Create state
      keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
                                 np.random.random((10, _NUM_CLASS)))
      original_preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))

      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      before_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

      # Check that original model state was not altered
      preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))
      self.assertAllClose(original_preds, preds, atol=1e-5)
      # Check that the original model compilation did not break
      keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
                                 np.random.random((10, _NUM_CLASS)))
Example #10
  def test_train_with_tf_optimizer(self):
    for model_type in ['sequential', 'functional']:
      keras_model, (_, _), (
          _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
              model_type=model_type, is_evaluate=True)
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer=rmsprop.RMSPropOptimizer(1e-3),
          metrics=['mse', keras.metrics.categorical_accuracy])

      with self.test_session():
        est_keras = keras_lib.model_to_estimator(
            keras_model=keras_model,
            # Also use dict config argument to get test coverage for that line.
            config={
                'tf_random_seed': _RANDOM_SEED,
                'model_dir': self._base_dir,
            })
        before_eval_results = est_keras.evaluate(
            input_fn=eval_input_fn, steps=1)
        est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
        after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
        self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

      writer_cache.FileWriterCache.clear()
      gfile.DeleteRecursively(self._config.model_dir)
Example #11
    def test_invalid_ionames_error(self):
        (x_train,
         y_train), (_,
                    _) = testing_utils.get_test_data(train_samples=_TRAIN_SIZE,
                                                     test_samples=100,
                                                     input_shape=(10, ),
                                                     num_classes=2)
        y_train = keras.utils.to_categorical(y_train)

        def invald_input_name_input_fn():
            input_dict = {'invalid_input_name': x_train}
            return input_dict, y_train

        def invald_output_name_input_fn():
            input_dict = {'input_1': x_train}
            output_dict = {'invalid_output_name': y_train}
            return input_dict, output_dict

        model = simple_functional_model()
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['acc'])
        with self.test_session():
            est_keras = keras_lib.model_to_estimator(keras_model=model,
                                                     config=self._config)

        with self.test_session():
            with self.assertRaises(ValueError):
                est_keras.train(input_fn=invald_input_name_input_fn, steps=100)

            with self.assertRaises(ValueError):
                est_keras.train(input_fn=invald_output_name_input_fn,
                                steps=100)
Example #12
  def test_invalid_ionames_error(self):
    (x_train, y_train), (_, _) = testing_utils.get_test_data(
        train_samples=_TRAIN_SIZE,
        test_samples=100,
        input_shape=(10,),
        num_classes=2)
    y_train = keras.utils.to_categorical(y_train)

    def invald_input_name_input_fn():
      input_dict = {'invalid_input_name': x_train}
      return input_dict, y_train

    def invald_output_name_input_fn():
      input_dict = {'input_1': x_train}
      output_dict = {'invalid_output_name': y_train}
      return input_dict, output_dict

    model = simple_functional_model()
    model.compile(
        loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=model, config=self._config)

    with self.test_session():
      with self.assertRaises(ValueError):
        est_keras.train(input_fn=invald_input_name_input_fn, steps=100)

      with self.assertRaises(ValueError):
        est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
Example #13
    def test_init_from_file(self):
        if h5py is None:
            return  # Skip test if models cannot be saved.

        keras_model, (x_train, y_train), (
            x_test, _), _, pred_input_fn = get_resource_for_simple_model(
                model_type='functional', is_evaluate=False)

        with self.test_session():
            keras_model.compile(loss='categorical_crossentropy',
                                optimizer='rmsprop',
                                metrics=['categorical_accuracy'])
            keras_model.fit(x_train, y_train, epochs=1)
            keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
            fname = os.path.join(self._base_dir, 'keras_model.h5')
            keras.models.save_model(keras_model, fname)

        with self.test_session():
            keras_est = keras_lib.model_to_estimator(keras_model_path=fname,
                                                     config=self._config)
            est_pred = [
                np.argmax(y[keras_model.output_names[0]])
                for y in keras_est.predict(input_fn=pred_input_fn)
            ]
        self.assertAllEqual(est_pred, keras_pred)
Example #14
  def test_init_from_file(self):
    if h5py is None:
      return  # Skip test if models cannot be saved.

    keras_model, (x_train, y_train), (
        x_test, _), _, pred_input_fn = get_resource_for_simple_model(
            model_type='functional', is_evaluate=False)

    with self.test_session():
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer='rmsprop',
          metrics=['categorical_accuracy'])
      keras_model.fit(x_train, y_train, epochs=1)
      keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
      fname = os.path.join(self._base_dir, 'keras_model.h5')
      keras.models.save_model(keras_model, fname)

    with self.test_session():
      keras_est = keras_lib.model_to_estimator(
          keras_model_path=fname, config=self._config)
      est_pred = [
          np.argmax(y[keras_model.output_names[0]])
          for y in keras_est.predict(input_fn=pred_input_fn)
      ]
    self.assertAllEqual(est_pred, keras_pred)
Example #15
    def test_train_with_hooks(self):
        for model_type in ['sequential', 'functional']:
            keras_model, (_, _), (
                _, _
            ), train_input_fn, eval_input_fn = get_resource_for_simple_model(
                model_type=model_type, is_evaluate=True)
            keras_model.compile(
                loss='categorical_crossentropy',
                optimizer=rmsprop.RMSPropOptimizer(1e-3),
                metrics=['mse', keras.metrics.categorical_accuracy])

            my_hook = MyHook()
            with self.test_session():
                est_keras = keras_lib.model_to_estimator(
                    keras_model=keras_model, config=self._config)
                before_eval_results = est_keras.evaluate(
                    input_fn=eval_input_fn, steps=1)
                est_keras.train(input_fn=train_input_fn,
                                hooks=[my_hook],
                                steps=_TRAIN_SIZE / 16)
                after_eval_results = est_keras.evaluate(input_fn=eval_input_fn,
                                                        steps=1)
                self.assertLess(after_eval_results['loss'],
                                before_eval_results['loss'])

            writer_cache.FileWriterCache.clear()
            gfile.DeleteRecursively(self._config.model_dir)
Example #16
  def test_gpu_config(self):
    with ops.Graph().as_default():
      keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer='rmsprop',
          metrics=['mse', keras.metrics.categorical_accuracy])

      gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
      sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
      self._config._session_config = sess_config
      keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      self.assertEqual(
          keras.backend.get_session()
          ._config.gpu_options.per_process_gpu_memory_fraction,
          gpu_options.per_process_gpu_memory_fraction)
Example #17
  def test_gpu_config(self):
    with ops.Graph().as_default():
      keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer='rmsprop',
          metrics=['mse', keras.metrics.categorical_accuracy])

      gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
      sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
      self._config._session_config = sess_config
      keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      self.assertEqual(
          keras.backend.get_session()
          ._config.gpu_options.per_process_gpu_memory_fraction,
          gpu_options.per_process_gpu_memory_fraction)
Example #18
    def test_pretrained_weights(self):
        keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
        keras_model.compile(
            loss='categorical_crossentropy',
            optimizer=rmsprop.RMSPropOptimizer(1e-3),
            metrics=['mse', keras.metrics.categorical_accuracy])
        with self.test_session():
            keras_model.train_on_batch(np.random.random((10, ) + _INPUT_SIZE),
                                       np.random.random((10, _NUM_CLASS)))
            weights = keras_model.get_weights()
            keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
            keras_model.set_weights(weights)
            keras_model.compile(
                loss='categorical_crossentropy',
                optimizer=SGD(lr=0.0001, momentum=0.9),
                metrics=['mse', keras.metrics.categorical_accuracy])
            keras_lib.model_to_estimator(keras_model=keras_model,
                                         config=self._config)
Example #19
    def test_custom_objects(self):
        keras_mobile = mobilenet.MobileNet(weights=None)
        keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam')
        custom_objects = {
            'relu6': mobilenet.relu6,
            'DepthwiseConv2D': mobilenet.DepthwiseConv2D
        }
        with self.assertRaisesRegexp(ValueError, 'relu6'):
            with self.test_session():
                keras_lib.model_to_estimator(
                    keras_model=keras_mobile,
                    model_dir=tempfile.mkdtemp(dir=self._base_dir))

        with self.test_session():
            keras_lib.model_to_estimator(
                keras_model=keras_mobile,
                model_dir=tempfile.mkdtemp(dir=self._base_dir),
                custom_objects=custom_objects)
Example #20
  def test_custom_objects(self):
    keras_mobile = mobilenet.MobileNet(weights=None)
    keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam')
    custom_objects = {
        'relu6': mobilenet.relu6,
        'DepthwiseConv2D': mobilenet.DepthwiseConv2D
    }
    with self.assertRaisesRegexp(ValueError, 'relu6'):
      with self.test_session():
        keras_lib.model_to_estimator(
            keras_model=keras_mobile,
            model_dir=tempfile.mkdtemp(dir=self._base_dir))

    with self.test_session():
      keras_lib.model_to_estimator(
          keras_model=keras_mobile,
          model_dir=tempfile.mkdtemp(dir=self._base_dir),
          custom_objects=custom_objects)
Example #21
  def test_pretrained_weights(self):
    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(1e-3),
        metrics=['mse', keras.metrics.categorical_accuracy])
    with self.test_session():
      keras_model.train_on_batch(
          np.random.random((10,) + _INPUT_SIZE),
          np.random.random((10, _NUM_CLASS)))
      weights = keras_model.get_weights()
      keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
      keras_model.set_weights(weights)
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer=SGD(lr=0.0001, momentum=0.9),
          metrics=['mse', keras.metrics.categorical_accuracy])
      keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
Example #22
    def test_multi_inputs_multi_outputs(self):
        np.random.seed(_RANDOM_SEED)
        (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
            train_samples=_TRAIN_SIZE,
            test_samples=50,
            input_shape=(16, ),
            num_classes=3)
        np.random.seed(_RANDOM_SEED)
        (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
            train_samples=_TRAIN_SIZE,
            test_samples=50,
            input_shape=(16, ),
            num_classes=2)
        np.random.seed(_RANDOM_SEED)
        (input_m_train,
         _), (input_m_test,
              _) = testing_utils.get_test_data(train_samples=_TRAIN_SIZE,
                                               test_samples=50,
                                               input_shape=(8, ),
                                               num_classes=2)

        c_train = keras.utils.to_categorical(c_train)
        c_test = keras.utils.to_categorical(c_test)
        d_train = keras.utils.to_categorical(d_train)
        d_test = keras.utils.to_categorical(d_test)

        def train_input_fn():
            input_dict = {
                'input_a': a_train,
                'input_b': b_train,
                'input_m': input_m_train.astype(np.str)
            }
            output_dict = {'dense_2': c_train, 'dense_3': d_train}
            return input_dict, output_dict

        def eval_input_fn():
            input_dict = {
                'input_a': a_test,
                'input_b': b_test,
                'input_m': input_m_test.astype(np.str)
            }
            output_dict = {'dense_2': c_test, 'dense_3': d_test}
            return input_dict, output_dict

        with self.test_session():
            model = multi_inputs_multi_outputs_model()
            est_keras = keras_lib.model_to_estimator(keras_model=model,
                                                     config=self._config)
            before_eval_results = est_keras.evaluate(input_fn=eval_input_fn,
                                                     steps=1)
            est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
            after_eval_results = est_keras.evaluate(input_fn=eval_input_fn,
                                                    steps=1)
            self.assertLess(after_eval_results['loss'],
                            before_eval_results['loss'])
Example #23
  def do_test_multi_inputs_multi_outputs_with_input_fn(
      self, train_input_fn, eval_input_fn, pred_input_fn):
    with self.cached_session():
      model = multi_inputs_multi_outputs_model()
      est_keras = keras_lib.model_to_estimator(
          keras_model=model, config=self._config)
      baseline_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
      est_keras.predict(input_fn=pred_input_fn)
Example #24
  def do_test_multi_inputs_multi_outputs_with_input_fn(
      self, train_input_fn, eval_input_fn, pred_input_fn):
    with self.cached_session():
      model = multi_inputs_multi_outputs_model()
      est_keras = keras_lib.model_to_estimator(
          keras_model=model, config=self._config)
      baseline_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
      est_keras.predict(input_fn=pred_input_fn)
Example #25
  def test_custom_objects(self):
    
    def relu6(x):
      return keras.backend.relu(x, max_value=6)
    
    keras_model = simple_functional_model(activation=relu6)
    keras_model.compile(loss='categorical_crossentropy', optimizer='adam')
    custom_objects = {
        'relu6': relu6
    }

    with self.assertRaisesRegexp(ValueError, 'relu6'):
      with self.test_session():
        keras_lib.model_to_estimator(
            keras_model=keras_model,
            model_dir=tempfile.mkdtemp(dir=self._base_dir))

    with self.test_session():
      keras_lib.model_to_estimator(
          keras_model=keras_model,
          model_dir=tempfile.mkdtemp(dir=self._base_dir),
          custom_objects=custom_objects)
Example #26
  def test_custom_objects(self):

    def relu6(x):
      return keras.backend.relu(x, max_value=6)

    keras_model = simple_functional_model(activation=relu6)
    keras_model.compile(loss='categorical_crossentropy', optimizer='adam')
    custom_objects = {
        'relu6': relu6
    }

    (x_train, y_train), _ = testing_utils.get_test_data(
        train_samples=_TRAIN_SIZE,
        test_samples=50,
        input_shape=(10,),
        num_classes=2)
    y_train = keras.utils.to_categorical(y_train, 2)
    input_name = keras_model.input_names[0]
    output_name = keras_model.output_names[0]
    train_input_fn = numpy_io.numpy_input_fn(
        x=randomize_io_type(x_train, input_name),
        y=randomize_io_type(y_train, output_name),
        shuffle=False,
        num_epochs=None,
        batch_size=16)
    with self.assertRaisesRegexp(ValueError, 'relu6'):
      with self.test_session():
        est = keras_lib.model_to_estimator(
            keras_model=keras_model,
            model_dir=tempfile.mkdtemp(dir=self._base_dir))
        est.train(input_fn=train_input_fn, steps=1)

    with self.test_session():
      est = keras_lib.model_to_estimator(
          keras_model=keras_model,
          model_dir=tempfile.mkdtemp(dir=self._base_dir),
          custom_objects=custom_objects)
      est.train(input_fn=train_input_fn, steps=1)
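Examples #2, #25, and #26 all exercise the same contract: converting the model without custom_objects fails with a ValueError that mentions 'relu6', since the model cloned inside model_to_estimator can only resolve the custom activation by name if custom_objects supplies it, while passing custom_objects={'relu6': relu6} lets both the conversion and the subsequent est.train call succeed.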
Example #27
  def test_tf_config(self):
    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.categorical_accuracy])

    tf_config = json.dumps({
        'cluster': {
            run_config_lib.TaskType.PS: ['localhost:1234'],
            run_config_lib.TaskType.WORKER: ['localhost:1236'],
            run_config_lib.TaskType.MASTER: ['localhost:1238']
        },
        'task': {
            'type': run_config_lib.TaskType.MASTER,
            'index': 0
        }
    })
    with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
      with self.test_session():
        keras_lib.model_to_estimator(
            keras_model=keras_model,
            model_dir=tempfile.mkdtemp(dir=self._base_dir))
Example #28
  def test_with_empty_config_and_empty_model_dir(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with self.cached_session():
      with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
        est_keras = keras_lib.model_to_estimator(
            keras_model=keras_model,
            config=run_config_lib.RunConfig())
        self.assertEqual(est_keras._model_dir, _TMP_DIR)
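In other words, when neither model_dir nor RunConfig.model_dir is set, model_to_estimator falls back to a fresh temporary directory; the test pins tempfile.mkdtemp to _TMP_DIR so that this fallback can be observed on est_keras._model_dir.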
Example #29
    def test_tf_config(self):
        keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
        keras_model.compile(
            loss='categorical_crossentropy',
            optimizer='rmsprop',
            metrics=['mse', keras.metrics.categorical_accuracy])

        tf_config = json.dumps({
            'cluster': {
                run_config_lib.TaskType.PS: ['localhost:1234'],
                run_config_lib.TaskType.WORKER: ['localhost:1236'],
                run_config_lib.TaskType.MASTER: ['localhost:1238']
            },
            'task': {
                'type': run_config_lib.TaskType.MASTER,
                'index': 0
            }
        })
        with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
            with self.test_session():
                keras_lib.model_to_estimator(
                    keras_model=keras_model,
                    model_dir=tempfile.mkdtemp(dir=self._base_dir))
Example #30
  def do_test_multi_inputs_multi_outputs_with_input_fn(
      self, train_input_fn, eval_input_fn):
    config = run_config_lib.RunConfig(
        tf_random_seed=_RANDOM_SEED,
        model_dir=self._base_dir,
        train_distribute=self._dist)
    with self.cached_session():
      model = multi_inputs_multi_outputs_model()
      est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
      baseline_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
Example #31
  def do_test_multi_inputs_multi_outputs_with_input_fn(
      self, distribution, train_input_fn, eval_input_fn):
    config = run_config_lib.RunConfig(
        tf_random_seed=_RANDOM_SEED,
        model_dir=self._base_dir,
        train_distribute=distribution)
    with self.cached_session():
      model = multi_inputs_multi_outputs_model()
      est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
      baseline_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
Example #32
  def test_with_empty_config_and_empty_model_dir(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.categorical_accuracy])

    with self.test_session():
      with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
        est_keras = keras_lib.model_to_estimator(
            keras_model=keras_model,
            config=run_config_lib.RunConfig())
        self.assertEqual(est_keras._model_dir, _TMP_DIR)
Example #33
  def test_multi_inputs_multi_outputs(self):
    np.random.seed(_RANDOM_SEED)
    (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
        train_samples=_TRAIN_SIZE,
        test_samples=50,
        input_shape=(16,),
        num_classes=3)
    np.random.seed(_RANDOM_SEED)
    (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
        train_samples=_TRAIN_SIZE,
        test_samples=50,
        input_shape=(16,),
        num_classes=2)
    np.random.seed(_RANDOM_SEED)
    (input_m_train, _), (input_m_test, _) = testing_utils.get_test_data(
        train_samples=_TRAIN_SIZE,
        test_samples=50,
        input_shape=(8,),
        num_classes=2)

    c_train = keras.utils.to_categorical(c_train)
    c_test = keras.utils.to_categorical(c_test)
    d_train = keras.utils.to_categorical(d_train)
    d_test = keras.utils.to_categorical(d_test)

    def train_input_fn():
      input_dict = {'input_a': a_train, 'input_b': b_train,
                    'input_m': input_m_train > 0}
      output_dict = {'dense_2': c_train, 'dense_3': d_train}
      return input_dict, output_dict

    def eval_input_fn():
      input_dict = {'input_a': a_test, 'input_b': b_test,
                    'input_m': input_m_test > 0}
      output_dict = {'dense_2': c_test, 'dense_3': d_test}
      return input_dict, output_dict

    with self.test_session():
      model = multi_inputs_multi_outputs_model()
      est_keras = keras_lib.model_to_estimator(
          keras_model=model, config=self._config)
      before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
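As in Example #22, the input_fn dictionaries are keyed by the layer names of multi_inputs_multi_outputs_model ('input_a', 'input_b', 'input_m' for features; 'dense_2', 'dense_3' for labels). Examples #11 and #12 cover the failure mode: a key that matches no model input or output name causes train to raise a ValueError.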
Example #34
  def test_train_with_subclassed_model(self):
    keras_model, (_, _), (
        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
            model_type='subclass', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(1e-3),
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      before_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
Example #35
  def test_train_with_subclassed_model(self):
    keras_model, (_, _), (
        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
            model_type='subclass', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(1e-3),
        metrics=['mse', keras.metrics.categorical_accuracy])

    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      before_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
Example #36
  def test_keras_model_init_error(self):
    with self.assertRaisesRegexp(ValueError, 'Either'):
      keras_lib.model_to_estimator()

    with self.test_session():
      keras_model = simple_sequential_model()
      with self.assertRaisesRegexp(ValueError, 'not both'):
        keras_lib.model_to_estimator(
            keras_model=keras_model,
            keras_model_path=tempfile.mkdtemp(dir=self._base_dir))

    with self.test_session():
      keras_model = simple_sequential_model()
      with self.assertRaisesRegexp(ValueError, 'compiled'):
        keras_lib.model_to_estimator(keras_model=keras_model)

    with self.test_session():
      keras_model = simple_sequential_model()
      with self.assertRaisesRegexp(ValueError, 'not a local path'):
        keras_lib.model_to_estimator(
            keras_model_path='gs://bucket/object')
Example #37
    def test_keras_model_init_error(self):
        with self.assertRaisesRegexp(ValueError, 'Either'):
            keras_lib.model_to_estimator()

        with self.test_session():
            keras_model = simple_sequential_model()
            with self.assertRaisesRegexp(ValueError, 'not both'):
                keras_lib.model_to_estimator(
                    keras_model=keras_model,
                    keras_model_path=tempfile.mkdtemp(dir=self._base_dir))

        with self.test_session():
            keras_model = simple_sequential_model()
            with self.assertRaisesRegexp(ValueError, 'compiled'):
                keras_lib.model_to_estimator(keras_model=keras_model)

        with self.test_session():
            keras_model = simple_sequential_model()
            with self.assertRaisesRegexp(ValueError, 'not a local path'):
                keras_lib.model_to_estimator(
                    keras_model_path='gs://bucket/object')
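Examples #36 and #37 document the argument validation in model_to_estimator: exactly one of keras_model and keras_model_path may be given (the 'Either' and 'not both' errors), the model must already be compiled, and keras_model_path must be a local path rather than a remote location such as 'gs://bucket/object'.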
Example #38
  def test_keras_optimizer_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.rmsprop(lr=0.01))

    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                               config=config)
      with self.assertRaisesRegexp(ValueError,
                                   'Only TensorFlow native optimizers are '
                                   'supported with DistributionStrategy.'):
        est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)
Example #39
  def test_keras_optimizer_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.rmsprop(lr=0.01))

    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                               config=config)
      with self.assertRaisesRegexp(ValueError,
                                   'Only TensorFlow native optimizers are '
                                   'supported with DistributionStrategy.'):
        est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)
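Examples #38 and #39 complement the DistributionStrategy tests (#6, #30, #31, #44, #48): the strategy is passed through RunConfig(train_distribute=...), and a model compiled with a native Keras optimizer (keras.optimizers.rmsprop) is rejected at train time because only TensorFlow-native optimizers are supported with DistributionStrategy.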
Example #40
  def test_train_with_model_fit_and_hooks(self):
    keras_model, (x_train, y_train), _, \
      train_input_fn, eval_input_fn = get_resource_for_simple_model(
          model_type='sequential', is_evaluate=True)

    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(1e-3),
        metrics=['mse', keras.metrics.CategoricalAccuracy()])
    my_hook = MyHook()
    with self.cached_session():
      keras_model.fit(x_train, y_train, epochs=1)

      keras_est = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      before_eval_results = keras_est.evaluate(input_fn=eval_input_fn)
      keras_est.train(input_fn=train_input_fn, hooks=[my_hook],
                      steps=_TRAIN_SIZE / 16)
      after_eval_results = keras_est.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
Example #41
  def test_train_with_model_fit_and_hooks(self):
    keras_model, (x_train, y_train), _, \
      train_input_fn, eval_input_fn = get_resource_for_simple_model(
          model_type='sequential', is_evaluate=True)

    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(1e-3),
        metrics=['mse', keras.metrics.categorical_accuracy])
    my_hook = MyHook()
    with self.test_session():
      keras_model.fit(x_train, y_train, epochs=1)

      keras_est = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      before_eval_results = keras_est.evaluate(input_fn=eval_input_fn)
      keras_est.train(input_fn=train_input_fn, hooks=[my_hook],
                      steps=_TRAIN_SIZE / 16)
      after_eval_results = keras_est.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
Example #42
    def test_evaluate(self):
        keras_model, (x_train, y_train), (
            x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
                model_type='functional', is_evaluate=True)

        with self.test_session():
            metrics = [
                'binary_accuracy', 'binary_crossentropy',
                'categorical_accuracy', 'categorical_crossentropy',
                'cosine_proximity', 'hinge', 'kullback_leibler_divergence',
                'mean_absolute_error', 'mean_absolute_percentage_error',
                'mean_squared_error', 'mean_squared_logarithmic_error',
                'poisson', 'squared_hinge', 'top_k_categorical_accuracy'
            ]
            keras_model.compile(loss='categorical_crossentropy',
                                optimizer='adam',
                                metrics=metrics)
            keras_model.fit(x_train, y_train, epochs=1)
            keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)

        with self.test_session():
            keras_est = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     config=self._config)
            est_eval = keras_est.evaluate(input_fn=eval_input_fn)

        metrics = ['loss'] + metrics

        # Check loss and all metrics match between keras and estimator.
        def shift(val):
            # Scale each value by its order of magnitude so the
            # assertAlmostEqual below is effectively a relative comparison.
            if val == 0:
                return 0
            else:
                return val / 10**int(log10(abs(val)))

        for i, metric_name in enumerate(metrics):
            self.assertAlmostEqual(
                shift(est_eval[metric_name]),
                shift(keras_eval[i]),
                places=4,
                msg='%s mismatch, keras model: %s, estimator: %s' %
                (metric_name, est_eval[metric_name], keras_eval[i]))
Example #43
  def test_train(self):
    for model_type in ['sequential', 'functional']:
      keras_model, (_, _), (
          _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
              model_type=model_type, is_evaluate=True)
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer='rmsprop',
          metrics=['mse', keras.metrics.categorical_accuracy])

      with self.test_session():
        est_keras = keras_lib.model_to_estimator(
            keras_model=keras_model, config=self._config)
        before_eval_results = est_keras.evaluate(
            input_fn=eval_input_fn, steps=1)
        est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
        after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
        self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

      writer_cache.FileWriterCache.clear()
      gfile.DeleteRecursively(self._config.model_dir)
Example #44
  def test_train_sequential_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        metrics=[keras.metrics.CategoricalAccuracy()],
        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=config)
      before_eval_results = est_keras.evaluate(
          input_fn=get_ds_test_input_fn, steps=1)
      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)
      after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
                                              steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)
Example #45
    def test_predict(self):
        # Check that predict on a pretrained model yields the same result.
        keras_model, (x_train, y_train), (
            x_test, _), _, pred_input_fn = get_resource_for_simple_model(
                model_type='sequential', is_evaluate=False)

        with self.test_session():
            keras_model.compile(loss='categorical_crossentropy',
                                optimizer='adam',
                                metrics=['accuracy'])
            keras_model.fit(x_train, y_train, epochs=1)
            keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]

        with self.test_session():
            keras_est = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     config=self._config)
            est_pred = [
                np.argmax(y[keras_model.output_names[0]])
                for y in keras_est.predict(input_fn=pred_input_fn)
            ]
        self.assertAllEqual(est_pred, keras_pred)
Example #46
  def test_evaluate(self):
    keras_model, (x_train, y_train), (
        x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
            model_type='functional', is_evaluate=True)

    with self.test_session():
      metrics = [
          'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy',
          'categorical_crossentropy', 'cosine_proximity', 'hinge',
          'kullback_leibler_divergence', 'mean_absolute_error',
          'mean_absolute_percentage_error', 'mean_squared_error',
          'mean_squared_logarithmic_error', 'poisson', 'squared_hinge',
          'top_k_categorical_accuracy'
      ]
      keras_model.compile(
          loss='categorical_crossentropy', optimizer='adam', metrics=metrics)
      keras_model.fit(x_train, y_train, epochs=1)
      keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)

    with self.test_session():
      keras_est = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      est_eval = keras_est.evaluate(input_fn=eval_input_fn)

    metrics = ['loss'] + metrics

    # Check loss and all metrics match between keras and estimator.
    def shift(val):
      # Scale each value by its order of magnitude so the
      # assertAlmostEqual below is effectively a relative comparison.
      if val == 0:
        return 0
      else:
        return val / 10**int(log10(abs(val)))

    for i, metric_name in enumerate(metrics):
      self.assertAlmostEqual(
          shift(est_eval[metric_name]),
          shift(keras_eval[i]),
          places=4,
          msg='%s mismatch, keras model: %s, estimator: %s' %
          (metric_name, est_eval[metric_name], keras_eval[i]))
Example #47
  def test_predict(self):
    # Check that predict on a pretrained model yields the same result.
    keras_model, (x_train, y_train), (
        x_test, _), _, pred_input_fn = get_resource_for_simple_model(
            model_type='sequential', is_evaluate=False)

    with self.test_session():
      keras_model.compile(
          loss='categorical_crossentropy',
          optimizer='adam',
          metrics=['accuracy'])
      keras_model.fit(x_train, y_train, epochs=1)
      keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]

    with self.test_session():
      keras_est = keras_lib.model_to_estimator(
          keras_model=keras_model, config=self._config)
      est_pred = [
          np.argmax(y[keras_model.output_names[0]])
          for y in keras_est.predict(input_fn=pred_input_fn)
      ]
    self.assertAllEqual(est_pred, keras_pred)
Example #48
  def test_train_functional_with_distribution_strategy(self):
    dist = mirrored_strategy.MirroredStrategy(
        devices=['/device:GPU:0', '/device:GPU:1'])
    keras_model = simple_functional_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=dist,
                                      eval_distribute=dist)
    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=config)
      before_eval_results = est_keras.evaluate(
          input_fn=get_ds_test_input_fn, steps=1)
      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)
      after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
                                              steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)