Esempio n. 1
0
    def testRewritingExporterHandlesError(self, base_exporter_mock):
        """A ValueError raised during rewriting must propagate out of export().

        The mocked base exporter runs ``_export_fn``. The test rewriter is
        built with ``True`` (presumably its raise-during-rewrite flag; confirm
        against ``_TestRewriter``). ``export`` is expected to raise a
        ``ValueError`` whose message matches ``rewrite-error``.

        After the failure, the test checks two things:
        - the base exporter was still called exactly once with the original
          arguments;
        - the rewriter was invoked.
        """
        base_exporter_mock.side_effect = _export_fn

        # The same positional arguments are passed to export() and then
        # verified on the mocked base exporter.
        export_args = (self._estimator, self._export_path,
                       self._checkpoint_path, self._eval_result,
                       self._is_the_final_export)

        rewriter = self._TestRewriter(True)
        exporter = converters.RewritingExporter(self._base_exporter, rewriter)
        with self.assertRaisesRegex(ValueError, '.*rewrite-error'):
            exporter.export(*export_args)

        base_exporter_mock.assert_called_once_with(*export_args)
        self.assertTrue(rewriter.rewrite_called)
Esempio n. 2
0
    def testRewritingHandlesNoBaseExport(self, base_exporter_mock):
        """No rewrite should happen when the base exporter produces nothing.

        The mocked base exporter returns ``None``. ``RewritingExporter.export``
        is then expected to:
        - return ``None``;
        - never call the rewriter;
        - still call the base exporter exactly once with the original
          arguments.
        """
        base_exporter_mock.return_value = None

        tr = self._TestRewriter(False)
        r_e = converters.RewritingExporter(self._base_exporter, tr)
        final_path = r_e.export(self._estimator, self._export_path,
                                self._checkpoint_path, self._eval_result,
                                self._is_the_final_export)
        # assertIsNone checks identity with None (the idiomatic test) and
        # reports a clearer failure message than assertEqual(x, None).
        self.assertIsNone(final_path)
        self.assertFalse(tr.rewrite_called)

        base_exporter_mock.assert_called_once_with(self._estimator,
                                                   self._export_path,
                                                   self._checkpoint_path,
                                                   self._eval_result,
                                                   self._is_the_final_export)
Esempio n. 3
0
    def testRewritingExporterSucceeds(self, base_exporter_mock):
        """A successful export returns the rewritten model's directory.

        The mocked base exporter runs ``_export_fn``; the test rewriter is
        built with ``False``. The test expects:
        - ``export`` returns ``<export_path>/BASE_EXPORT_SUBDIR``;
        - the rewriter was invoked;
        - that directory contains ``REWRITTEN_SAVED_MODEL`` and, under the
          SavedModel assets directory, ``REWRITTEN_VOCAB``;
        - the base exporter was called exactly once with the original
          arguments.
        """
        base_exporter_mock.side_effect = _export_fn

        tr = self._TestRewriter(False)
        r_e = converters.RewritingExporter(self._base_exporter, tr)
        final_path = r_e.export(self._estimator, self._export_path,
                                self._checkpoint_path, self._eval_result,
                                self._is_the_final_export)
        self.assertEqual(final_path,
                         os.path.join(self._export_path, BASE_EXPORT_SUBDIR))
        # The sibling tests pin rewrite_called for the error and no-export
        # paths; the success path must likewise confirm the rewriter ran.
        self.assertTrue(tr.rewrite_called)
        self.assertTrue(
            fileio.exists(os.path.join(final_path, REWRITTEN_SAVED_MODEL)))
        self.assertTrue(
            fileio.exists(
                os.path.join(final_path, tf.saved_model.ASSETS_DIRECTORY,
                             REWRITTEN_VOCAB)))

        base_exporter_mock.assert_called_once_with(self._estimator,
                                                   self._export_path,
                                                   self._checkpoint_path,
                                                   self._eval_result,
                                                   self._is_the_final_export)