def test_latest_exporter(self):
    """LatestExporter.export forwards its configuration to export_savedmodel.

    The estimator is a mock. The test therefore only checks that the export
    directory, receiver fn, assets_extra, as_text and checkpoint path are
    passed through (with strip_default_attrs=True), and that the mock's
    return value is returned unchanged.
    """
    import os

    def _serving_input_receiver_fn():
      pass

    # mkdtemp() returns a path without a trailing separator. Plain string
    # concatenation would create a sibling "<tmpdir>export/" directory outside
    # the temporary directory, so join the components instead.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export")
    gfile.MkDir(export_dir_base)

    exporter = exporter_lib.LatestExporter(
        name="latest_exporter",
        serving_input_receiver_fn=_serving_input_receiver_fn,
        assets_extra={"from/path": "to/path"},
        as_text=False,
        exports_to_keep=5)
    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    estimator.export_savedmodel.return_value = "export_result_path"

    export_result = exporter.export(estimator, export_dir_base,
                                    "checkpoint_path", {}, False)

    self.assertEqual("export_result_path", export_result)
    estimator.export_savedmodel.assert_called_with(
        export_dir_base,
        _serving_input_receiver_fn,
        assets_extra={"from/path": "to/path"},
        as_text=False,
        checkpoint_path="checkpoint_path",
        strip_default_attrs=True)
  def test_garbage_collect_exports(self):
    """Exporting with exports_to_keep=2 deletes all but the 2 newest exports.

    Four export directories are created up front. After one export through
    the (mocked) estimator, the two oldest must be gone and the two newest
    must remain. Recency is determined from the timestamp directory names.
    """
    import os

    # mkdtemp() has no trailing separator; join rather than concatenate so the
    # export base lives inside the temporary directory.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export")
    gfile.MkDir(export_dir_base)
    # Created oldest-first; the list preserves creation order.
    export_dirs = [_create_test_export_dir(export_dir_base) for _ in range(4)]

    for export_dir in export_dirs:
      self.assertTrue(gfile.Exists(export_dir))

    def _serving_input_receiver_fn():
      return array_ops.constant([1]), None

    exporter = exporter_lib.LatestExporter(
        name="latest_exporter",
        serving_input_receiver_fn=_serving_input_receiver_fn,
        exports_to_keep=2)
    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    # Garbage collect all but the most recent 2 exports,
    # where recency is determined based on the timestamp directory names.
    exporter.export(estimator, export_dir_base, None, None, False)

    oldest, older, newer, newest = export_dirs
    self.assertFalse(gfile.Exists(oldest))
    self.assertFalse(gfile.Exists(older))
    self.assertTrue(gfile.Exists(newer))
    self.assertTrue(gfile.Exists(newest))
  def test_error_out_if_exports_to_keep_is_zero(self):
    """Constructing a LatestExporter with exports_to_keep=0 raises ValueError.

    The error message is expected to mention "positive number".
    """
    def _serving_input_receiver_fn():
      pass

    # Only the constructor call belongs inside the context manager. Once it
    # raises, any later statement in this block would be skipped and never run.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
    # was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "positive number"):
      exporter_lib.LatestExporter(
          name="latest_exporter",
          serving_input_receiver_fn=_serving_input_receiver_fn,
          exports_to_keep=0)
Example no. 4
0
 def _get_exporter(self, name, fc):
   """Returns a LatestExporter whose serving input parses tf.Example protos.

   Args:
     name: Name given to the exporter.
     fc: Passed to `feature_column.make_parse_example_spec`; presumably the
       feature columns the model consumes (confirm against callers).

   Returns:
     An `exporter_lib.LatestExporter` using a parsing serving input receiver
     built from the feature spec of `fc`.
   """
   receiver_fn = export_lib.build_parsing_serving_input_receiver_fn(
       feature_column.make_parse_example_spec(fc))
   return exporter_lib.LatestExporter(
       name=name, serving_input_receiver_fn=receiver_fn)