def testRewritingExporterHandlesError(self, base_exporter_mock):
    """Checks that an error raised during rewriting propagates out of export().

    `base_exporter_mock` is presumably injected by a `mock.patch` decorator
    outside this view; its side effect is delegated to `_export_fn`. The test
    rewriter is built with `True`, and the test expects export() to raise a
    ValueError whose message matches 'rewrite-error'.
    """
    base_exporter_mock.side_effect = _export_fn
    failing_rewriter = self._TestRewriter(True)
    exporter = converters.RewritingExporter(self._base_exporter,
                                            failing_rewriter)
    # The same five arguments are used for the export call and for the
    # mock-call verification below.
    export_args = (
        self._estimator,
        self._export_path,
        self._checkpoint_path,
        self._eval_result,
        self._is_the_final_export,
    )

    with self.assertRaisesRegex(ValueError, '.*rewrite-error'):
        exporter.export(*export_args)

    # The base exporter was called exactly once with the unchanged arguments,
    # and the rewriter was actually invoked.
    base_exporter_mock.assert_called_once_with(*export_args)
    self.assertTrue(failing_rewriter.rewrite_called)
def testRewritingHandlesNoBaseExport(self, base_exporter_mock):
    """Checks that export() skips rewriting when the base export yields None.

    The mocked base exporter (`base_exporter_mock`, presumably injected by a
    `mock.patch` decorator outside this view) returns None. RewritingExporter
    must then also return None, must not invoke the rewriter, and must still
    call the base exporter exactly once with the original arguments.
    """
    base_exporter_mock.return_value = None
    tr = self._TestRewriter(False)
    r_e = converters.RewritingExporter(self._base_exporter, tr)
    final_path = r_e.export(self._estimator, self._export_path,
                            self._checkpoint_path, self._eval_result,
                            self._is_the_final_export)
    # assertIsNone states the intent directly and produces a clearer failure
    # message than assertEqual(final_path, None).
    self.assertIsNone(final_path)
    self.assertFalse(tr.rewrite_called)
    base_exporter_mock.assert_called_once_with(self._estimator,
                                               self._export_path,
                                               self._checkpoint_path,
                                               self._eval_result,
                                               self._is_the_final_export)
def testRewritingExporterSucceeds(self, base_exporter_mock):
    """Checks the happy path: base export followed by a successful rewrite.

    The mocked base exporter (presumably injected by a `mock.patch` decorator
    outside this view) delegates to `_export_fn`. The test rewriter is built
    with `False` (presumably the non-failing configuration; confirm against
    `_TestRewriter`). export() must return `<export_path>/BASE_EXPORT_SUBDIR`.
    That directory must contain the rewritten saved model and, under the
    SavedModel assets directory, the rewritten vocabulary file.
    """
    base_exporter_mock.side_effect = _export_fn
    rewriter = self._TestRewriter(False)
    exporter = converters.RewritingExporter(self._base_exporter, rewriter)
    export_args = (
        self._estimator,
        self._export_path,
        self._checkpoint_path,
        self._eval_result,
        self._is_the_final_export,
    )

    final_path = exporter.export(*export_args)

    expected_path = os.path.join(self._export_path, BASE_EXPORT_SUBDIR)
    self.assertEqual(final_path, expected_path)

    rewritten_model = os.path.join(final_path, REWRITTEN_SAVED_MODEL)
    self.assertTrue(fileio.exists(rewritten_model))

    rewritten_vocab = os.path.join(final_path,
                                   tf.saved_model.ASSETS_DIRECTORY,
                                   REWRITTEN_VOCAB)
    self.assertTrue(fileio.exists(rewritten_vocab))

    base_exporter_mock.assert_called_once_with(*export_args)