Exemple #1
0
  def test_make_best_model_export_strategy(self):
    """Checks which eval results the best-model strategy chooses to export.

    Four eval results are fed through `export` in order, with losses 100, 101,
    10 and 20. The test expects a non-empty return value for the first and
    third calls and "" for the second and fourth, and checks that the
    estimator's `last_exported_dir` tracks the most recent non-empty export.
    """
    # Join with a real path separator so the export directory lives inside the
    # freshly created temp directory. Plain string concatenation
    # (mkdtemp() + "export/") yields a sibling path such as
    # "/tmp/tmpabcexport/" that sits outside of it.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export")
    gfile.MkDir(export_dir_base)

    test_estimator = TestEstimator()
    export_strategy = saved_model_export_utils.make_best_model_export_strategy(
        serving_input_fn=None, exports_to_keep=3, compare_fn=None)

    # First result: expected to be exported.
    self.assertNotEqual("",
                        export_strategy.export(test_estimator, export_dir_base,
                                               "fake_ckpt_0", {"loss": 100}))
    self.assertNotEqual("", test_estimator.last_exported_dir)
    self.assertNotEqual("", test_estimator.last_exported_checkpoint)

    # Loss 101 (higher than 100): expected not to be exported, so the last
    # exported directory still refers to fake_ckpt_0.
    self.assertEqual("",
                     export_strategy.export(test_estimator, export_dir_base,
                                            "fake_ckpt_1", {"loss": 101}))
    self.assertEqual(test_estimator.last_exported_dir,
                     os.path.join(export_dir_base, "fake_ckpt_0"))

    # Loss 10 (lower than 100): expected to be exported.
    self.assertNotEqual("",
                        export_strategy.export(test_estimator, export_dir_base,
                                               "fake_ckpt_2", {"loss": 10}))
    self.assertEqual(test_estimator.last_exported_dir,
                     os.path.join(export_dir_base, "fake_ckpt_2"))

    # Loss 20 (higher than 10): expected not to be exported, so fake_ckpt_2
    # remains the last exported directory.
    self.assertEqual("",
                     export_strategy.export(test_estimator, export_dir_base,
                                            "fake_ckpt_3", {"loss": 20}))
    self.assertEqual(test_estimator.last_exported_dir,
                     os.path.join(export_dir_base, "fake_ckpt_2"))
Exemple #2
0
  def test_make_best_model_export_strategy_exceptions(self):
    """Checks that `export` raises ValueError for missing inputs.

    Two cases are exercised: an empty checkpoint path, and an eval result of
    None.
    """
    # Function-scope import keeps this block self-contained; it is harmless if
    # the module already imports os at the top.
    import os

    # Join with a real path separator. Plain string concatenation
    # (mkdtemp() + "export/") yields a sibling path such as
    # "/tmp/tmpabcexport/" outside the freshly created temp directory.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export")

    test_estimator = TestEstimator()
    export_strategy = saved_model_export_utils.make_best_model_export_strategy(
        serving_input_fn=None, exports_to_keep=3, compare_fn=None)

    # Empty checkpoint path.
    with self.assertRaises(ValueError):
      export_strategy.export(test_estimator, export_dir_base, "", {"loss": 200})

    # Eval result of None.
    with self.assertRaises(ValueError):
      export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1",
                             None)
Exemple #3
0
  def test_make_best_model_export_strategy_with_preemption(self):
    """Checks the strategy when eval summaries already exist under model_dir.

    Two eval summaries (loss 50 at step 1, loss 60 at step 2) are written to
    `<model_dir>/eval_continuous` before the strategy is created with that
    `model_dir` and a matching `event_file_pattern`. The test then expects a
    result with loss 100 not to be exported, one with loss 10 to be exported,
    and a following one with loss 20 not to be exported.
    """
    model_dir = self.get_temp_dir()
    summary_dir = os.path.join(model_dir, "eval_continuous")
    # Pre-existing eval results: (step 1, loss 50) then (step 2, loss 60).
    for step, loss in enumerate((50, 60), 1):
      core_estimator._write_dict_to_summary(summary_dir, {"loss": loss}, step)

    estimator_under_test = TestEstimator()
    strategy = saved_model_export_utils.make_best_model_export_strategy(
        serving_input_fn=None,
        exports_to_keep=3,
        model_dir=model_dir,
        event_file_pattern="eval_continuous/*.tfevents.*",
        compare_fn=None)

    export_root = os.path.join(self.get_temp_dir(), "export")
    best_export_dir = os.path.join(export_root, "fake_ckpt_2")

    # Loss 100: expected not to be exported, leaving the estimator untouched.
    exported = strategy.export(estimator_under_test, export_root,
                               "fake_ckpt_0", {"loss": 100})
    self.assertEqual("", exported)
    self.assertEqual("", estimator_under_test.last_exported_dir)
    self.assertEqual("", estimator_under_test.last_exported_checkpoint)

    # Loss 10: expected to be exported.
    exported = strategy.export(estimator_under_test, export_root,
                               "fake_ckpt_2", {"loss": 10})
    self.assertNotEqual("", exported)
    self.assertEqual(estimator_under_test.last_exported_dir, best_export_dir)

    # Loss 20: expected not to be exported, so fake_ckpt_2 remains the last
    # exported directory.
    exported = strategy.export(estimator_under_test, export_root,
                               "fake_ckpt_3", {"loss": 20})
    self.assertEqual("", exported)
    self.assertEqual(estimator_under_test.last_exported_dir, best_export_dir)