def test_extend_export_strategy_raises_error(self):
        """Export fails when post_export_fn returns a path outside new_path.

        The post-export hook below ignores the directory it is handed and
        returns a brand-new temp directory instead; the extended strategy's
        export is expected to reject that with a ValueError whose message
        mentions "post_export_fn must return a sub-directory".
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        def _base_export_fn(unused_estimator,
                            export_dir_base,
                            unused_checkpoint_path=None):
            base_path = os.path.join(export_dir_base, "e1")
            gfile.MkDir(base_path)
            return base_path

        def _post_export_fn(unused_orig_path, unused_new_path):
            # Deliberately NOT a sub-directory of unused_new_path. Register it
            # for removal so the test does not leak directories.
            stray_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, stray_dir, ignore_errors=True)
            return stray_dir

        base_export_strategy = export_strategy_lib.ExportStrategy(
            "Servo", _base_export_fn)

        final_export_strategy = saved_model_export_utils.extend_export_strategy(
            base_export_strategy, _post_export_fn)

        test_estimator = TestEstimator()
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        with self.assertRaises(ValueError) as ve:
            final_export_strategy.export(test_estimator, tmpdir,
                                         os.path.join(tmpdir, "checkpoint"))

        # assertIn reports both strings on failure, unlike assertTrue(x in y).
        self.assertIn("post_export_fn must return a sub-directory",
                      str(ve.exception))
  def test_extend_export_strategy_raises_error(self):
    """Export fails when post_export_fn returns a path outside new_path.

    The post-export hook below ignores the directory it is handed and returns
    a brand-new temp directory instead; the extended strategy's export is
    expected to reject that with a ValueError whose message mentions
    "post_export_fn must return a sub-directory".
    """
    import shutil  # Local import: only needed for temp-dir cleanup.

    def _base_export_fn(unused_estimator,
                        export_dir_base,
                        unused_checkpoint_path=None):
      base_path = os.path.join(export_dir_base, "e1")
      gfile.MkDir(base_path)
      return base_path

    def _post_export_fn(unused_orig_path, unused_new_path):
      # Deliberately NOT a sub-directory of unused_new_path. Register it for
      # removal so the test does not leak directories.
      stray_dir = tempfile.mkdtemp()
      self.addCleanup(shutil.rmtree, stray_dir, ignore_errors=True)
      return stray_dir

    base_export_strategy = export_strategy_lib.ExportStrategy(
        "Servo", _base_export_fn)

    final_export_strategy = saved_model_export_utils.extend_export_strategy(
        base_export_strategy, _post_export_fn)

    test_estimator = TestEstimator()
    tmpdir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
    with self.assertRaises(ValueError) as ve:
      final_export_strategy.export(test_estimator, tmpdir,
                                   os.path.join(tmpdir, "checkpoint"))

    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("post_export_fn must return a sub-directory",
                  str(ve.exception))
    def test_extend_export_strategy_same_name(self):
        """Extending without a new name keeps "Servo"; rewrite lands in place.

        The post-export hook creates a "rewrite" sub-directory of the path it
        is handed. The final export path is expected to be
        <export_model_dir>/rewrite.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        def _base_export_fn(unused_estimator,
                            export_dir_base,
                            unused_checkpoint_path=None):
            base_path = os.path.join(export_dir_base, "e1")
            gfile.MkDir(base_path)
            return base_path

        def _post_export_fn(orig_path, new_path):
            # Compare the last path component rather than endswith("/e1"):
            # _base_export_fn builds the path with os.path.join, which uses
            # the platform separator, so a hard-coded "/" check is POSIX-only.
            # self.assertEqual (unlike a bare assert) survives `python -O`.
            self.assertEqual("e1", os.path.basename(orig_path))
            post_export_path = os.path.join(new_path, "rewrite")
            gfile.MkDir(post_export_path)
            return post_export_path

        base_export_strategy = export_strategy_lib.ExportStrategy(
            "Servo", _base_export_fn)

        final_export_strategy = saved_model_export_utils.extend_export_strategy(
            base_export_strategy, _post_export_fn)
        self.assertEqual(final_export_strategy.name, "Servo")

        test_estimator = TestEstimator()
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        export_model_dir = os.path.join(tmpdir, "model")
        checkpoint_path = os.path.join(tmpdir, "checkpoint")
        final_path = final_export_strategy.export(test_estimator,
                                                  export_model_dir,
                                                  checkpoint_path)
        self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path)
  def test_extend_export_strategy_same_name(self):
    """Extending without a new name keeps "Servo"; rewrite lands in place.

    The post-export hook creates a "rewrite" sub-directory of the path it is
    handed. The final export path is expected to be
    <export_model_dir>/rewrite.
    """
    import shutil  # Local import: only needed for temp-dir cleanup.

    def _base_export_fn(unused_estimator,
                        export_dir_base,
                        unused_checkpoint_path=None):
      base_path = os.path.join(export_dir_base, "e1")
      gfile.MkDir(base_path)
      return base_path

    def _post_export_fn(orig_path, new_path):
      # Compare the last path component rather than endswith("/e1"):
      # _base_export_fn builds the path with os.path.join, which uses the
      # platform separator, so a hard-coded "/" check is POSIX-only.
      # self.assertEqual (unlike a bare assert) survives `python -O`.
      self.assertEqual("e1", os.path.basename(orig_path))
      post_export_path = os.path.join(new_path, "rewrite")
      gfile.MkDir(post_export_path)
      return post_export_path

    base_export_strategy = export_strategy_lib.ExportStrategy(
        "Servo", _base_export_fn)

    final_export_strategy = saved_model_export_utils.extend_export_strategy(
        base_export_strategy, _post_export_fn)
    self.assertEqual(final_export_strategy.name, "Servo")

    test_estimator = TestEstimator()
    tmpdir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
    export_model_dir = os.path.join(tmpdir, "model")
    checkpoint_path = os.path.join(tmpdir, "checkpoint")
    final_path = final_export_strategy.export(test_estimator, export_model_dir,
                                              checkpoint_path)
    self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path)
예제 #5
0
  def test_extend_export_strategy(self):
    """A renamed extension reports "Servo2" and chains both path hooks.

    Both hooks only build path strings (nothing is written to disk by them).
    Exporting to "/path/to/orig" is expected to yield
    "/path/to/orig/e1/rewrite".
    """
    # Parameter names are kept as-is: export hooks may be inspected by
    # argument name, so only the helper names themselves are changed.
    def _export_under_e1(unused_estimator, export_dir_base,
                         unused_checkpoint_path=None):
      return export_dir_base + "/e1"

    def _append_rewrite(orig_path):
      return orig_path + "/rewrite"

    servo_strategy = export_strategy_lib.ExportStrategy(
        "Servo", _export_under_e1)
    extended_strategy = saved_model_export_utils.extend_export_strategy(
        servo_strategy, _append_rewrite, "Servo2")
    self.assertEqual(extended_strategy.name, "Servo2")

    exported_path = extended_strategy.export(
        TestEstimator(), "/path/to/orig", "/path/to/checkpoint")
    expected_path = "/path/to/orig/e1/rewrite"
    self.assertEqual(expected_path, exported_path)