def test_non_concrete_error(self):
    """Passing a non-concrete `def_function.function` to export must fail.

    `obj.f` is traced once, then handed to `export.export` as-is rather than
    via `get_concrete_function`; export is expected to raise a ValueError
    whose message says the function must be converted to a concrete one.
    """
    obj = tracking.Checkpointable()
    obj.f = def_function.function(lambda x: 2. * x)
    obj.f(constant_op.constant(1.))
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    expected_message = "must be converted to concrete functions"
    with self.assertRaisesRegexp(ValueError, expected_message):
        export.export(obj, target_dir, obj.f)
def test_nested_outputs(self):
    """Exporting a concrete function with nested tuple outputs must fail.

    The traced function returns `(2x, (3x, 4x))`; export is expected to raise
    a ValueError mentioning "non-flat outputs".

    NOTE(review): another `test_nested_outputs` with the same body appears
    later in this listing. Inside a class body the later definition replaces
    this one, so this copy likely never runs — confirm and drop one of them.
    """
    obj = tracking.Checkpointable()
    obj.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
    obj.f(constant_op.constant(1.))
    concrete = obj.f.get_concrete_function(constant_op.constant(1.))
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegexp(ValueError, "non-flat outputs"):
        export.export(obj, target_dir, concrete)
def test_single_function_default_signature(self):
    """Export with no explicit signatures when the object has one function.

    `model.f` takes no arguments and returns 3.0. After exporting without a
    signatures argument, loading and running the result through
    `self._import_and_infer` with no feeds should yield {"output_0": 3.}.
    """
    model = tracking.Checkpointable()
    model.f = def_function.function(lambda: 3., input_signature=())
    model.f()
    saved_to = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(model, saved_to)
    expected = {"output_0": 3.}
    inferred = self._import_and_infer(saved_to, {})
    self.assertAllClose(expected, inferred)
def test_nested_outputs(self):
    """A concrete function returning nested tuples cannot be exported.

    Export should raise ValueError with a "non-flat outputs" message.

    NOTE(review): an identical `test_nested_outputs` also appears earlier in
    this listing; as the later definition, this is the one a class body keeps.
    """
    checkpointable_root = tracking.Checkpointable()
    checkpointable_root.f = def_function.function(
        lambda x: (2. * x, (3. * x, 4. * x)))
    checkpointable_root.f(constant_op.constant(1.))
    fn = checkpointable_root.f.get_concrete_function(constant_op.constant(1.))
    out_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegexp(
            ValueError, "non-flat outputs"):
        export.export(checkpointable_root, out_dir, fn)
def test_ambiguous_signatures(self):
    """Default-signature export fails when the choice of function is ambiguous.

    Besides whatever `model(x, y)` traces, the model carries an extra
    `second_function`. Exporting with no explicit signatures must raise a
    ValueError whose message mentions "call" followed by "second_function".
    """
    model = _ModelWithOptimizer()
    x_value = constant_op.constant([[3., 4.]])
    y_value = constant_op.constant([2.])
    model(x_value, y_value)
    model.second_function = def_function.function(lambda: 1.)
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegexp(ValueError, "call.*second_function"):
        export.export(model, target_dir)
def test_method_export_signature(self):
    """A function with an input_signature is exported directly.

    Feeding x=1.0 to the exported doubling function should give
    {"output_0": 2.}.

    NOTE(review): another `test_method_export_signature` with the same body
    appears later in this listing. Inside a class body the later definition
    replaces this one — confirm and drop one of them.
    """
    scalar_float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root = tracking.Checkpointable()
    root.f = def_function.function(
        lambda x: 2. * x, input_signature=[scalar_float_spec])
    root.f(constant_op.constant(1.))
    out_dir = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(root, out_dir, root.f)
    result = self._import_and_infer(out_dir, {"x": 1.})
    self.assertEqual({"output_0": 2.}, result)
def test_nested_dict_outputs(self):
    """Dict outputs with a non-Tensor (tuple) value cannot be exported.

    The "b" entry of the returned dict is a tuple, so export is expected to
    raise a ValueError about a dictionary containing a non-Tensor value.
    """
    obj = tracking.Checkpointable()
    obj.f = def_function.function(
        lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
    obj.f(constant_op.constant(1.))
    concrete = obj.f.get_concrete_function(constant_op.constant(1.))
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    expected_message = "dictionary containing non-Tensor value"
    with self.assertRaisesRegexp(ValueError, expected_message):
        export.export(obj, target_dir, concrete)
def test_nested_inputs(self):
    """A concrete function taking a list argument cannot be exported.

    The function is specialized on a two-element list of tensors; export is
    expected to raise a ValueError mentioning non-unique argument names.
    """
    obj = tracking.Checkpointable()
    obj.f = def_function.function(lambda x: 2. * x[0])
    obj.f([constant_op.constant(1.)])
    two_element_input = [constant_op.constant(1.), constant_op.constant(2.)]
    concrete = obj.f.get_concrete_function(two_element_input)
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegexp(ValueError, "non-unique argument names"):
        export.export(obj, target_dir, concrete)
def test_no_reference_cycles(self):
    """Call and export `self._model`; skipped on Python 2.

    NOTE(review): nothing in this body asserts the absence of reference
    cycles; presumably a decorator on this test (not visible here) performs
    the garbage check — confirm.
    """
    inputs_x = constant_op.constant([[3., 4.]])
    inputs_y = constant_op.constant([2.])
    self._model(inputs_x, inputs_y)
    running_py2 = sys.version_info[0] < 3
    if running_py2:
        # TODO(allenl): debug reference cycles in Python 2.x
        self.skipTest(
            "This test only works in Python 3+. Reference cycles are "
            "created in older Python versions.")
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(self._model, target_dir, self._model.call)
def test_variable(self):
    """Variables captured by an exported function are saved with it.

    `f` computes v1 * v2 * x with v1=3. and v2=2.; feeding x=2. to the
    exported model should therefore give {"output_0": 12.}.
    """
    root = tracking.Checkpointable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    root.f(constant_op.constant(1.))
    to_export = root.f.get_concrete_function(constant_op.constant(1.))
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(root, export_dir, to_export)
    # assertAllClose compares dicts key-by-key (as the other export tests do);
    # assertAllEqual would wrap each dict in a 0-d object array and only pass
    # through plain dict equality, with an unhelpful failure message.
    self.assertAllClose(
        {"output_0": 12.},
        self._import_and_infer(export_dir, {"x": 2.}))
def test_method_export_signature(self):
    """Export a function that already declares its input_signature.

    The exported doubling function, fed x=1.0, must return {"output_0": 2.}.

    NOTE(review): an identical `test_method_export_signature` appears earlier
    in this listing; as the later definition, this is the one a class keeps.
    """
    checkpointable_root = tracking.Checkpointable()
    signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
    checkpointable_root.f = def_function.function(
        lambda x: 2. * x, input_signature=signature)
    checkpointable_root.f(constant_op.constant(1.))
    saved_to = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(checkpointable_root, saved_to, checkpointable_root.f)
    outputs = self._import_and_infer(saved_to, {"x": 1.})
    self.assertEqual({"output_0": 2.}, outputs)
def test_single_method_default_signature(self):
    """Export a model without explicit signatures and check its outputs.

    After one call on sample inputs, the model is exported with no
    signatures argument; running the result with matching feeds must
    produce an output keyed "loss".
    """
    model = _ModelWithOptimizer()
    x_value = constant_op.constant([[3., 4.]])
    y_value = constant_op.constant([2.])
    model(x_value, y_value)
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(model, target_dir)
    feeds = {"x": [[3., 4.]], "y": [2.]}
    outputs = self._import_and_infer(target_dir, feeds)
    self.assertIn("loss", outputs)
def test_method_export_concrete(self):
    """Export a concrete function under a caller-chosen signature key.

    The function returns {"out": 2z}; it is exported as "non_default_key",
    and inference through that key with z=1.0 must return {"out": 2.}.

    NOTE(review): another `test_method_export_concrete` with the same body
    appears later in this listing. Inside a class body the later definition
    replaces this one — confirm and drop one of them.
    """
    obj = tracking.Checkpointable()
    obj.f = def_function.function(lambda z: {"out": 2. * z})
    obj.f(constant_op.constant(1.))
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    any_shape_float = tensor_spec.TensorSpec(None, dtypes.float32)
    signatures = {
        "non_default_key": obj.f.get_concrete_function(any_shape_float),
    }
    export.export(obj, target_dir, signatures)
    result = self._import_and_infer(
        target_dir, {"z": 1.}, signature_key="non_default_key")
    self.assertEqual({"out": 2.}, result)
def test_optimizer(self):
    """The exported model reproduces the live model's state at export time.

    The test expects consecutive calls on the same inputs to give different
    losses (presumably `_ModelWithOptimizer` applies an optimizer step per
    call). The exported `call` must reproduce `second_loss`, the value the
    live model computed immediately after export.
    """
    x = constant_op.constant([[3., 4.]])
    y = constant_op.constant([2.])
    model = _ModelWithOptimizer()
    first_loss = model(x, y)
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(model, export_dir, model.call)
    second_loss = model(x, y)
    # Compare values, not objects: with identity-based `Tensor.__eq__`,
    # assertNotEqual on two tensors (or dicts of tensors) passes merely
    # because they are distinct objects. assertAllClose evaluates and
    # compares numerically (including nested dicts), so require it to fail.
    with self.assertRaises(AssertionError):
        self.assertAllClose(first_loss, second_loss)
    self.assertAllClose(
        second_loss,
        self._import_and_infer(export_dir, {
            "x": [[3., 4.]],
            "y": [2.]
        }))
def test_method_export_concrete(self):
    """A concrete function can be exported under a non-default signature key.

    Inference through "non_default_key" with z=1.0 must give {"out": 2.}.

    NOTE(review): an identical `test_method_export_concrete` appears earlier
    in this listing; as the later definition, this is the one a class keeps.
    """
    checkpointable_root = tracking.Checkpointable()
    checkpointable_root.f = def_function.function(
        lambda z: {"out": 2. * z})
    checkpointable_root.f(constant_op.constant(1.))
    saved_to = os.path.join(self.get_temp_dir(), "saved_model")
    concrete = checkpointable_root.f.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    export.export(checkpointable_root, saved_to,
                  {"non_default_key": concrete})
    outputs = self._import_and_infer(
        saved_to, {"z": 1.}, signature_key="non_default_key")
    self.assertEqual({"out": 2.}, outputs)
def test_trivial_export_exception(self):
    """Exporting an empty Checkpointable with no signatures must fail.

    A ValueError whose message mentions "signature" is expected.
    """
    empty_root = tracking.Checkpointable()
    target_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegexp(ValueError, "signature"):
        export.export(empty_root, target_dir)