Example #1
  def test_save_checkpoint_both_steps_and_secs_are_not_none(self):
    empty_config = run_config_lib.RunConfig()
    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_ERR):
      empty_config.replace(save_checkpoints_steps=100,
                           save_checkpoints_secs=200)

    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_ERR):
      run_config_lib.RunConfig(save_checkpoints_steps=100,
                               save_checkpoints_secs=200)
Example #2
    def test_init_run_config_duplicate_distribute(self):
        with self.assertRaises(ValueError):
            run_config_lib.RunConfig(
                train_distribute=tf.distribute.MirroredStrategy(),
                experimental_distribute=DistributeConfig(
                    train_distribute=tf.distribute.MirroredStrategy()))

        with self.assertRaises(ValueError):
            run_config_lib.RunConfig(
                eval_distribute=tf.distribute.MirroredStrategy(),
                experimental_distribute=DistributeConfig(
                    eval_distribute=tf.distribute.MirroredStrategy()))
Example #3
 def setUp(self):
     self._base_dir = os.path.join(self.get_temp_dir(),
                                   'keras_estimator_test')
     gfile.MakeDirs(self._base_dir)
     self._config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                             model_dir=self._base_dir)
     super(TestKerasEstimator, self).setUp()
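`_RANDOM_SEED` is another module-level fixture; the tests only need a fixed integer for reproducibility, so any value works (this one is hypothetical):

    # Hypothetical seed; only its stability across runs matters.
    _RANDOM_SEED = 1337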
Example #4
 def test_incompatible_eval_strategy(self):
   with self.assertRaisesRegex(
       ValueError, 'Please use `tf.compat.v1.distribut'
       'e.experimental.ParameterServerStrategy`'):
      run_config_lib.RunConfig(
          # __new__ yields an uninitialized strategy instance; that is enough
          # to trigger the version check without standing up a real cluster.
          eval_distribute=parameter_server_strategy_v2.ParameterServerStrategyV2
          .__new__(parameter_server_strategy_v2.ParameterServerStrategyV2))
Example #5
  def test_with_empty_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    est_keras = keras_lib.model_to_estimator(
        keras_model=keras_model,
        model_dir=self._base_dir,
        config=run_config_lib.RunConfig())
    self.assertEqual(run_config_lib.get_default_session_config(),
                     est_keras._session_config)
    self.assertEqual(est_keras._session_config,
                     est_keras._config.session_config)
    self.assertEqual(self._base_dir, est_keras._config.model_dir)
    self.assertEqual(self._base_dir, est_keras._model_dir)

    est_keras = keras_lib.model_to_estimator(
        keras_model=keras_model, model_dir=self._base_dir, config=None)
    self.assertEqual(run_config_lib.get_default_session_config(),
                     est_keras._session_config)
    self.assertEqual(est_keras._session_config,
                     est_keras._config.session_config)
    self.assertEqual(self._base_dir, est_keras._config.model_dir)
    self.assertEqual(self._base_dir, est_keras._model_dir)
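For orientation, here is a self-contained version of the pattern Examples #5 and #20 exercise, using only public APIs on a TF version that still ships the estimator package. The model architecture and compile arguments are arbitrary stand-ins for what `get_resource_for_simple_model` returns:

    import tensorflow as tf

    # Arbitrary stand-in for the helper-built Sequential model.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop',
                  metrics=['mse', tf.keras.metrics.CategoricalAccuracy()])

    # Public counterpart of keras_lib.model_to_estimator; an empty RunConfig
    # picks up the default session config, as the test asserts.
    est = tf.keras.estimator.model_to_estimator(
        keras_model=model, config=tf.estimator.RunConfig())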
Example #6
    def test_train_sequential_with_distribution_strategy(
            self, distribution, cloning):
        keras_model = simple_sequential_model()
        keras_model.compile(
            loss='categorical_crossentropy',
            metrics=[keras.metrics.CategoricalAccuracy()],
            optimizer=rmsprop_keras.RMSprop(learning_rate=0.01),
            cloning=cloning)
        config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                          model_dir=self._base_dir,
                                          train_distribute=distribution)
        with self.cached_session():
            est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                                     config=config)
            before_eval_results = est_keras.evaluate(
                input_fn=get_ds_test_input_fn, steps=1)
            est_keras.train(input_fn=get_ds_train_input_fn,
                            steps=_TRAIN_SIZE / 16)
            after_eval_results = est_keras.evaluate(
                input_fn=get_ds_test_input_fn, steps=1)
            self.assertLess(after_eval_results['loss'],
                            before_eval_results['loss'])

        writer_cache.FileWriterCache.clear()
        gfile.DeleteRecursively(self._config.model_dir)
Example #7
    def test_replace_invalid_values(self):
        config = run_config_lib.RunConfig()

        with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):
            config.replace(model_dir='')
        with self.assertRaisesRegexp(ValueError, _SAVE_SUMMARY_STEPS_ERR):
            config.replace(save_summary_steps=-1)
        with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_STEPS_ERR):
            config.replace(save_checkpoints_steps=-1)
        with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_SECS_ERR):
            config.replace(save_checkpoints_secs=-1)
        with self.assertRaisesRegexp(ValueError, _SESSION_CONFIG_ERR):
            config.replace(session_config={})
        with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_MAX_ERR):
            config.replace(keep_checkpoint_max=-1)
        with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_HOURS_ERR):
            config.replace(keep_checkpoint_every_n_hours=0)
        with self.assertRaisesRegexp(ValueError,
                                     _SESSION_CREATION_TIMEOUT_SECS_ERR):
            config.replace(session_creation_timeout_secs=0)
        with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
            config.replace(tf_random_seed=1.0)
        with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
            config.replace(device_fn=lambda x, y: 0)
        with self.assertRaisesRegexp(ValueError,
                                     _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR):
            config.replace(experimental_max_worker_delay_secs='5')
Example #8
    def _get_estimator(self,
                       train_distribute,
                       eval_distribute,
                       remote_cluster=None):
        input_dimension = LABEL_DIMENSION
        linear_feature_columns = [
            tf.feature_column.numeric_column("x", shape=(input_dimension, ))
        ]
        dnn_feature_columns = [
            tf.feature_column.numeric_column("x", shape=(input_dimension, ))
        ]

        return dnn_linear_combined.DNNLinearCombinedRegressor(
            linear_feature_columns=linear_feature_columns,
            dnn_hidden_units=(2, 2),
            dnn_feature_columns=dnn_feature_columns,
            label_dimension=LABEL_DIMENSION,
            model_dir=self._model_dir,
            dnn_optimizer="Adagrad",
            linear_optimizer="Adagrad",
            config=run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    train_distribute=train_distribute,
                    eval_distribute=eval_distribute,
                    remote_cluster=remote_cluster)))
Example #9
 def test_previously_unexpected_cluster_spec(self):
     with tf.compat.v1.test.mock.patch.dict(
             "os.environ",
         {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
         run_config_lib.RunConfig(experimental_distribute=DistributeConfig(
             train_distribute=tf.distribute.MirroredStrategy(
                 ["/device:GPU:0", "/device:GPU:1"])))
Example #10
 def test_save_checkpoint_flip_steps_to_none(self):
     config_with_steps = run_config_lib.RunConfig().replace(
         save_checkpoints_steps=100)
     config_without_ckpt = config_with_steps.replace(
         save_checkpoints_steps=None)
     self.assertIsNone(config_without_ckpt.save_checkpoints_steps)
     self.assertIsNone(config_without_ckpt.save_checkpoints_secs)
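Examples #1, #10 and #22 together pin down the rule: the two checkpoint cadences are mutually exclusive, and `replace()` clears whichever one you do not pass. A compact recap:

    config = run_config_lib.RunConfig()        # defaults: secs=600, steps=None
    config = config.replace(save_checkpoints_steps=100)   # steps=100, secs=None
    config = config.replace(save_checkpoints_steps=None)  # both None (Example #10)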
Example #11
 def test_replace_with_disallowed_properties(self):
   config = run_config_lib.RunConfig()
   with self.assertRaises(ValueError):
      # master is not allowed to be replaced.
     config.replace(master='_master')
   with self.assertRaises(ValueError):
     config.replace(some_undefined_property=123)
Example #12
 def setUp(self):
     super(TestEstimatorDistributionStrategy, self).setUp()
     strategy_combinations.set_virtual_cpus_to_at_least(3)
     self._base_dir = os.path.join(self.get_temp_dir(),
                                   'keras_to_estimator_strategy_test')
     gfile.MakeDirs(self._base_dir)
     self._config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                             model_dir=self._base_dir)
Example #13
 def test_init_run_config_independent_worker(self):
     # When `train_distribute` is specified and TF_CONFIG is detected, use
     # distribute coordinator with INDEPENDENT_WORKER mode.
     with tf.compat.v1.test.mock.patch.dict(
             "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
         config = run_config_lib.RunConfig(
             train_distribute=tf.distribute.MirroredStrategy())
     self.assertEqual(config._distribute_coordinator_mode,
                      dc.CoordinatorMode.INDEPENDENT_WORKER)
Example #14
 def test_init_run_config_standalone_client(self):
     # When `train_distribute` is specified, TF_CONFIG is detected, and
     # `experimental_distribute.remote_cluster` is set, use distribute
     # coordinator with STANDALONE_CLIENT mode.
     config = run_config_lib.RunConfig(
         train_distribute=tf.distribute.MirroredStrategy(),
         experimental_distribute=DistributeConfig(
             remote_cluster={"chief": ["fake_worker"]}))
     self.assertEqual(config._distribute_coordinator_mode,
                      dc.CoordinatorMode.STANDALONE_CLIENT)
Example #15
    def test_should_run_distribute_coordinator(self):
        """Tests that should_run_distribute_coordinator return a correct value."""
        # We don't use distribute coordinator for local training.
        self.assertFalse(
            dc_training.should_run_distribute_coordinator(
                run_config_lib.RunConfig()))

        # When `train_distribute` is not specified, don't use distribute
        # coordinator.
        with tf.compat.v1.test.mock.patch.dict(
                "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
            self.assertFalse(
                dc_training.should_run_distribute_coordinator(
                    run_config_lib.RunConfig()))

        # When `train_distribute` is specified and TF_CONFIG is detected, use
        # distribute coordinator.
        with tf.compat.v1.test.mock.patch.dict(
                "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
            config_with_train_distribute = run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    train_distribute=tf.distribute.MirroredStrategy(
                        ["/device:GPU:0", "/device:GPU:1"])))
            config_with_eval_distribute = run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    eval_distribute=tf.distribute.MirroredStrategy(
                        ["/device:GPU:0", "/device:GPU:1"])))
        self.assertTrue(
            dc_training.should_run_distribute_coordinator(
                config_with_train_distribute))
        self.assertFalse(
            dc_training.should_run_distribute_coordinator(
                config_with_eval_distribute))

        # With a master in the cluster, don't run distribute coordinator.
        with tf.compat.v1.test.mock.patch.dict(
                "os.environ",
            {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
            config = run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    train_distribute=tf.distribute.MirroredStrategy(
                        ["/device:GPU:0", "/device:GPU:1"])))
        self.assertFalse(dc_training.should_run_distribute_coordinator(config))
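Example #15's assertions pin down the decision rule. Paraphrased below as a standalone predicate; this is a reading of the test, not the library's actual implementation:

    # Paraphrase of the asserted behaviour: the coordinator runs only for
    # distributed *training* with a cluster spec that has no legacy 'master'.
    def _should_run_coordinator(train_distribute, cluster_spec):
        return (train_distribute is not None
                and bool(cluster_spec)
                and 'master' not in cluster_spec)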
Example #16
  def test_replace(self):
    config = run_config_lib.RunConfig()

    with self.assertRaisesRegexp(
        ValueError, _NOT_SUPPORTED_REPLACE_PROPERTY_MSG):
      # master is not allowed to be replaced.
      config.replace(master=_MASTER)

    with self.assertRaisesRegexp(
        ValueError, _NOT_SUPPORTED_REPLACE_PROPERTY_MSG):
      config.replace(some_undefined_property=_MASTER)
Example #17
 def test_default_values(self):
   self._assert_distributed_properties(
       run_config=run_config_lib.RunConfig(),
       expected_cluster_spec={},
       expected_task_type=run_config_lib.TaskType.WORKER,
       expected_task_id=0,
       expected_master='',
       expected_evaluation_master='',
       expected_is_chief=True,
       expected_num_worker_replicas=1,
       expected_num_ps_replicas=0)
Example #18
  def test_init_run_config_none_distribute_coordinator_mode(self):
    # We don't use distribute coordinator for local training.
    config = run_config_lib.RunConfig(
        train_distribute=tf.distribute.MirroredStrategy())
    dc_training.init_run_config(config, {})
    self.assertIsNone(config._distribute_coordinator_mode)

    # With a master in the cluster, don't run distribute coordinator.
    with tf.compat.v1.test.mock.patch.dict(
        "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
      config = run_config_lib.RunConfig(
          train_distribute=tf.distribute.MirroredStrategy())
      self.assertIsNone(config._distribute_coordinator_mode)

    # When `train_distribute` is not specified, don't use distribute
    # coordinator.
    with tf.compat.v1.test.mock.patch.dict(
        "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
      config = run_config_lib.RunConfig()
      self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
Example #19
 def test_default_property_values(self):
   config = run_config_lib.RunConfig()
   self.assertIsNone(config.model_dir)
   self.assertIsNone(config.session_config)
   self.assertIsNone(config.tf_random_seed)
   self.assertEqual(100, config.save_summary_steps)
   self.assertEqual(600, config.save_checkpoints_secs)
   self.assertIsNone(config.save_checkpoints_steps)
   self.assertEqual(5, config.keep_checkpoint_max)
   self.assertEqual(10000, config.keep_checkpoint_every_n_hours)
   self.assertIsNone(config.service)
   self.assertIsNone(config.device_fn)
Example #20
  def test_with_empty_config_and_empty_model_dir(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=run_config_lib.RunConfig())
      self.assertEqual(est_keras._model_dir, _TMP_DIR)
Example #21
 def test_init_invalid_values(self):
   with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):
     run_config_lib.RunConfig(model_dir='')
   with self.assertRaisesRegexp(ValueError, _SAVE_SUMMARY_STEPS_ERR):
     run_config_lib.RunConfig(save_summary_steps=-1)
   with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_STEPS_ERR):
     run_config_lib.RunConfig(save_checkpoints_steps=-1)
   with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_SECS_ERR):
     run_config_lib.RunConfig(save_checkpoints_secs=-1)
   with self.assertRaisesRegexp(ValueError, _SESSION_CONFIG_ERR):
     run_config_lib.RunConfig(session_config={})
   with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_MAX_ERR):
     run_config_lib.RunConfig(keep_checkpoint_max=-1)
   with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_HOURS_ERR):
     run_config_lib.RunConfig(keep_checkpoint_every_n_hours=0)
   with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
     run_config_lib.RunConfig(tf_random_seed=1.0)
   with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
     run_config_lib.RunConfig(device_fn=lambda x: '/cpu:0')
Example #22
  def test_save_checkpoint(self):
    empty_config = run_config_lib.RunConfig()
    self.assertEqual(600, empty_config.save_checkpoints_secs)
    self.assertIsNone(empty_config.save_checkpoints_steps)

    config_with_steps = empty_config.replace(save_checkpoints_steps=100)
    del empty_config
    self.assertEqual(100, config_with_steps.save_checkpoints_steps)
    self.assertIsNone(config_with_steps.save_checkpoints_secs)

    config_with_secs = config_with_steps.replace(save_checkpoints_secs=200)
    del config_with_steps
    self.assertEqual(200, config_with_secs.save_checkpoints_secs)
    self.assertIsNone(config_with_secs.save_checkpoints_steps)
Example #23
 def do_test_multi_inputs_multi_outputs_with_input_fn(
     self, distribution, train_input_fn, eval_input_fn):
   config = run_config_lib.RunConfig(
       tf_random_seed=_RANDOM_SEED,
       model_dir=self._base_dir,
       train_distribute=distribution)
   with self.cached_session():
     model = multi_inputs_multi_outputs_model()
     est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
     baseline_eval_results = est_keras.evaluate(
         input_fn=eval_input_fn, steps=1)
     est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
     eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
     self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
Example #24
  def test_with_conflicting_model_dir_and_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with self.assertRaisesRegexp(
        ValueError, '`model_dir` are set both in '
        'constructor and `RunConfig`'):
      keras_lib.model_to_estimator(
          keras_model=keras_model,
          model_dir=self._base_dir,
          config=run_config_lib.RunConfig(model_dir=_TMP_DIR))
Example #25
 def test_default_property_values(self):
   config = run_config_lib.RunConfig()
   self.assertIsNone(config.model_dir)
   self.assertIsNone(config.session_config)
   self.assertIsNone(config.tf_random_seed)
   self.assertEqual(100, config.save_summary_steps)
   self.assertEqual(600, config.save_checkpoints_secs)
   self.assertIsNone(config.save_checkpoints_steps)
   self.assertEqual(5, config.keep_checkpoint_max)
   self.assertEqual(10000, config.keep_checkpoint_every_n_hours)
   self.assertIsNone(config.service)
   self.assertIsNone(config.device_fn)
   self.assertIsNone(config.experimental_max_worker_delay_secs)
   self.assertEqual(7200, config.session_creation_timeout_secs)
   self.assertTrue(config.checkpoint_save_graph_def)
Example #26
        def _test_metric_fn(metric_fn):
            input_fn = get_input_fn(x=np.arange(4)[:, None, None],
                                    y=np.ones(4)[:, None])
            config = run_config.RunConfig(log_step_count_steps=1)
            estimator = linear.LinearClassifierV2([fc.numeric_column('x')],
                                                  config=config)

            estimator = extenders.add_metrics(estimator, metric_fn)

            estimator.train(input_fn=input_fn)
            metrics = estimator.evaluate(input_fn=input_fn)
            self.assertIn('mean_x', metrics)
            self.assertEqual(1.5, metrics['mean_x'])
            # assert that it keeps original estimators metrics
            self.assertIn('auc', metrics)
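`get_input_fn` is another test-local helper. One plausible shape, wrapping the numpy arrays into a finite, batched `tf.data` input function (finite, so the `train()` call above terminates without an explicit `steps`); the details are assumptions, not the real helper:

    import numpy as np
    import tensorflow as tf

    # Plausible sketch; the real helper may batch or repeat differently.
    def get_input_fn(x, y):
        def input_fn():
            return tf.data.Dataset.from_tensor_slices(({'x': x}, y)).batch(2)
        return input_fn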
Example #27
    def test_estimator_with_strategy_hooks(self, distribution,
                                           use_train_and_evaluate):
        config = run_config.RunConfig(eval_distribute=distribution)

        def _input_map_fn(tensor):
            return {'feature': tensor}, tensor

        def input_fn():
            return tf.data.Dataset.from_tensors(
                [1.]).repeat(10).batch(5).map(_input_map_fn)

        def model_fn(features, labels, mode):
            del features, labels
            global_step = tf.compat.v1.train.get_global_step()
            if mode == model_fn_lib.ModeKeys.TRAIN:
                train_hook1 = tf.compat.v1.train.StepCounterHook(
                    every_n_steps=1, output_dir=self.get_temp_dir())
                train_hook2 = tf.compat.v1.test.mock.MagicMock(
                    wraps=tf.compat.v1.train.SessionRunHook(),
                    spec=tf.compat.v1.train.SessionRunHook)
                return model_fn_lib.EstimatorSpec(
                    mode,
                    loss=tf.constant(1.),
                    train_op=global_step.assign_add(1),
                    training_hooks=[train_hook1, train_hook2])
            if mode == model_fn_lib.ModeKeys.EVAL:
                eval_hook1 = tf.compat.v1.train.StepCounterHook(
                    every_n_steps=1, output_dir=self.get_temp_dir())
                eval_hook2 = tf.compat.v1.test.mock.MagicMock(
                    wraps=tf.compat.v1.train.SessionRunHook(),
                    spec=tf.compat.v1.train.SessionRunHook)
                return model_fn_lib.EstimatorSpec(
                    mode=mode,
                    loss=tf.constant(1.),
                    evaluation_hooks=[eval_hook1, eval_hook2])

        num_steps = 10
        estimator = estimator_lib.EstimatorV2(model_fn=model_fn,
                                              model_dir=self.get_temp_dir(),
                                              config=config)
        if use_train_and_evaluate:
            training.train_and_evaluate(
                estimator, training.TrainSpec(input_fn, max_steps=num_steps),
                training.EvalSpec(input_fn))
        else:
            estimator.train(input_fn, steps=num_steps)
            estimator.evaluate(input_fn, steps=num_steps)
Example #28
    def test_init_with_allowed_properties(self):
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        device_fn = lambda op: "/cpu:0"

        config = run_config_lib.RunConfig(tf_random_seed=11,
                                          save_summary_steps=12,
                                          save_checkpoints_secs=14,
                                          session_config=session_config,
                                          keep_checkpoint_max=16,
                                          keep_checkpoint_every_n_hours=17,
                                          device_fn=device_fn)
        self.assertEqual(11, config.tf_random_seed)
        self.assertEqual(12, config.save_summary_steps)
        self.assertEqual(14, config.save_checkpoints_secs)
        self.assertEqual(session_config, config.session_config)
        self.assertEqual(16, config.keep_checkpoint_max)
        self.assertEqual(17, config.keep_checkpoint_every_n_hours)
        self.assertEqual(device_fn, config.device_fn)
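`device_fn` is invoked once per `tf.Operation` and must return a device string; the lambda in Example #28 pins every op to CPU. An equivalent named version (the routing hint in the comment is illustrative):

    def pin_to_cpu(op):
        # 'op' is a tf.Operation; inspect op.type or op.name here to route
        # selectively instead of pinning everything.
        return '/cpu:0'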
Example #29
 def test_init_none_value(self):
     config = run_config_lib.RunConfig(tf_random_seed=None,
                                       model_dir=None,
                                       save_summary_steps=None,
                                       save_checkpoints_secs=None,
                                       save_checkpoints_steps=None,
                                       session_config=None,
                                       keep_checkpoint_max=None,
                                       keep_checkpoint_every_n_hours=None,
                                       device_fn=None)
     self.assertIsNone(config.tf_random_seed)
     self.assertIsNone(config.model_dir)
     self.assertIsNone(config.save_summary_steps)
     self.assertIsNone(config.save_checkpoints_secs)
     self.assertIsNone(config.save_checkpoints_steps)
     self.assertIsNone(config.session_config)
     self.assertIsNone(config.keep_checkpoint_max)
     self.assertIsNone(config.keep_checkpoint_every_n_hours)
     self.assertIsNone(config.device_fn)
Example #30
  def test_replace_with_allowed_properties(self):
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    device_fn = lambda op: '/cpu:0'

    config = run_config_lib.RunConfig().replace(
        tf_random_seed=11,
        save_summary_steps=12,
        save_checkpoints_secs=14,
        session_config=session_config,
        keep_checkpoint_max=16,
        keep_checkpoint_every_n_hours=17,
        device_fn=device_fn,
        session_creation_timeout_secs=18,
        checkpoint_save_graph_def=False)
    self.assertEqual(11, config.tf_random_seed)
    self.assertEqual(12, config.save_summary_steps)
    self.assertEqual(14, config.save_checkpoints_secs)
    self.assertEqual(session_config, config.session_config)
    self.assertEqual(16, config.keep_checkpoint_max)
    self.assertEqual(17, config.keep_checkpoint_every_n_hours)
    self.assertEqual(device_fn, config.device_fn)
    self.assertEqual(18, config.session_creation_timeout_secs)
    self.assertFalse(config.checkpoint_save_graph_def)
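A final point the examples rely on implicitly: `replace()` is non-destructive and returns a fresh `RunConfig`, which is why Example #22 can safely `del` the old instance. A minimal check:

    base = run_config_lib.RunConfig()
    tuned = base.replace(save_summary_steps=50)
    assert base.save_summary_steps == 100   # original untouched (default)
    assert tuned.save_summary_steps == 50   # new instance carries the change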