Пример #1
0
 def test_non_concrete_error(self):
     """Exporting a plain (non-concrete) tf.function as a signature fails.

     `f` is called once before export, but the object handed to `export` is
     still the tf.function itself, which export is expected to reject.
     """
     obj = tracking.Checkpointable()
     obj.f = def_function.function(lambda x: 2. * x)
     obj.f(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     self.assertRaisesRegexp(
         ValueError, "must be converted to concrete functions",
         export.export, obj, save_dir, obj.f)
Пример #2
0
 def test_non_concrete_error(self):
   """Rejects a tf.function that was not converted to a concrete function."""
   holder = tracking.Checkpointable()
   holder.f = def_function.function(lambda x: 2. * x)
   holder.f(constant_op.constant(1.))
   target_dir = os.path.join(self.get_temp_dir(), "saved_model")
   expected = "must be converted to concrete functions"
   with self.assertRaisesRegexp(ValueError, expected):
     export.export(holder, target_dir, holder.f)
Пример #3
0
 def test_nested_outputs(self):
     """A concrete function returning a nested tuple cannot be exported."""
     obj = tracking.Checkpointable()
     obj.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
     obj.f(constant_op.constant(1.))
     concrete_fn = obj.f.get_concrete_function(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     self.assertRaisesRegexp(ValueError, "non-flat outputs",
                             export.export, obj, save_dir, concrete_fn)
Пример #4
0
 def test_single_function_default_signature(self):
     """Exports without explicit signatures and serves the lone function.

     `f` takes no inputs, so inference runs with an empty feed and is
     expected to return f's constant under the output name "output_0".
     """
     model = tracking.Checkpointable()
     model.f = def_function.function(lambda: 3., input_signature=())
     model.f()
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     export.export(model, save_dir)
     outputs = self._import_and_infer(save_dir, {})
     self.assertAllClose({"output_0": 3.}, outputs)
Пример #5
0
 def test_nested_outputs(self):
   """Export refuses signatures whose outputs are not a flat structure."""
   holder = tracking.Checkpointable()
   # The second element of the returned tuple is itself a tuple.
   holder.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
   holder.f(constant_op.constant(1.))
   signature_fn = holder.f.get_concrete_function(constant_op.constant(1.))
   target_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(ValueError, "non-flat outputs"):
     export.export(holder, target_dir, signature_fn)
Пример #6
0
 def test_ambiguous_signatures(self):
     """Export without signatures fails when several functions are candidates.

     After invoking the model once, a second tf.function is attached; export
     is expected to raise a ValueError naming both `call` and
     `second_function`.
     """
     model = _ModelWithOptimizer()
     x_value = constant_op.constant([[3., 4.]])
     y_value = constant_op.constant([2.])
     model(x_value, y_value)
     model.second_function = def_function.function(lambda: 1.)
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     with self.assertRaisesRegexp(ValueError, "call.*second_function"):
         export.export(model, save_dir)
Пример #7
0
 def test_method_export_signature(self):
     """A tf.function with an input_signature can be exported directly.

     The function doubles a float32 input; inference on x=1. should return
     2. under the output name "output_0".
     """
     float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
     root = tracking.Checkpointable()
     root.f = def_function.function(lambda x: 2. * x,
                                    input_signature=[float_spec])
     root.f(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     export.export(root, save_dir, root.f)
     result = self._import_and_infer(save_dir, {"x": 1.})
     self.assertEqual({"output_0": 2.}, result)
Пример #8
0
 def test_nested_dict_outputs(self):
   """Dict outputs must map to Tensors, not to nested tuples of Tensors."""
   obj = tracking.Checkpointable()
   obj.f = def_function.function(
       lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
   obj.f(constant_op.constant(1.))
   concrete_fn = obj.f.get_concrete_function(constant_op.constant(1.))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   self.assertRaisesRegexp(
       ValueError, "dictionary containing non-Tensor value",
       export.export, obj, save_dir, concrete_fn)
Пример #9
0
 def test_nested_dict_outputs(self):
   """Exporting fails when a returned dict holds a non-Tensor value.

   Key "b" maps to a tuple of two tensors rather than a single Tensor.
   """
   holder = tracking.Checkpointable()
   holder.f = def_function.function(
       lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
   holder.f(constant_op.constant(1.))
   signature_fn = holder.f.get_concrete_function(constant_op.constant(1.))
   target_dir = os.path.join(self.get_temp_dir(), "saved_model")
   message = "dictionary containing non-Tensor value"
   with self.assertRaisesRegexp(ValueError, message):
     export.export(holder, target_dir, signature_fn)
Пример #10
0
 def test_nested_inputs(self):
   """A concrete function taking a list argument cannot be exported.

   Presumably the flattened list elements all derive their names from the
   single argument `x`, which export rejects as non-unique argument names.
   """
   obj = tracking.Checkpointable()
   obj.f = def_function.function(lambda x: 2. * x[0])
   obj.f([constant_op.constant(1.)])
   pair = [constant_op.constant(1.), constant_op.constant(2.)]
   concrete_fn = obj.f.get_concrete_function(pair)
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   self.assertRaisesRegexp(ValueError, "non-unique argument names",
                           export.export, obj, save_dir, concrete_fn)
Пример #11
0
 def test_no_reference_cycles(self):
     """Exports the model's `call` function; meant to surface reference cycles.

     NOTE(review): nothing in this body inspects garbage or reference counts;
     presumably a decorator outside this view performs that check around the
     test. `self._model` is presumably created in setUp (not visible here).
     Both assumptions need confirming.
     """
     x = constant_op.constant([[3., 4.]])
     y = constant_op.constant([2.])
     # The model is run once before the Python-version check and before export.
     # NOTE(review): this ordering may matter to whatever leak check wraps the
     # test; keep it unless that is confirmed otherwise.
     self._model(x, y)
     if sys.version_info[0] < 3:
         # TODO(allenl): debug reference cycles in Python 2.x
         self.skipTest(
             "This test only works in Python 3+. Reference cycles are "
             "created in older Python versions.")
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     export.export(self._model, export_dir, self._model.call)
Пример #12
0
 def test_variable(self):
     """Variables captured by an exported function are available on import.

     `f` multiplies its input by two captured variables (3. and 2.), so
     inference on the imported model with x=2. should yield 12.
     """
     root = tracking.Checkpointable()
     root.v1 = variables.Variable(3.)
     root.v2 = variables.Variable(2.)
     root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
     root.f(constant_op.constant(1.))
     to_export = root.f.get_concrete_function(constant_op.constant(1.))
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     export.export(root, export_dir, to_export)
     # assertAllClose recurses into dict structures. assertAllEqual is meant
     # for arrays/tensors and would compare the dicts only as opaque objects.
     self.assertAllClose({"output_0": 12.},
                         self._import_and_infer(export_dir, {"x": 2.}))
Пример #13
0
 def test_method_export_signature(self):
   """Exports `f` directly, relying on its declared input_signature."""
   holder = tracking.Checkpointable()
   spec = tensor_spec.TensorSpec(None, dtypes.float32)
   holder.f = def_function.function(lambda x: 2. * x, input_signature=[spec])
   holder.f(constant_op.constant(1.))
   target_dir = os.path.join(self.get_temp_dir(), "saved_model")
   export.export(holder, target_dir, holder.f)
   expected = {"output_0": 2.}
   self.assertEqual(expected,
                    self._import_and_infer(target_dir, {"x": 1.}))
Пример #14
0
 def test_single_method_default_signature(self):
     """Exports a model without explicit signatures and checks inference.

     The imported model's default signature is expected to produce an output
     keyed "loss" when fed the same x/y values used to invoke the model.
     """
     model = _ModelWithOptimizer()
     model(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     export.export(model, save_dir)
     feeds = {"x": [[3., 4.]], "y": [2.]}
     self.assertIn("loss", self._import_and_infer(save_dir, feeds))
Пример #15
0
 def test_method_export_concrete(self):
     """A concrete function can be exported under a caller-chosen key.

     The exported signature is registered as "non_default_key" and must be
     requested by that key at inference time.
     """
     root = tracking.Checkpointable()
     root.f = def_function.function(lambda z: {"out": 2. * z})
     root.f(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     concrete = root.f.get_concrete_function(
         tensor_spec.TensorSpec(None, dtypes.float32))
     export.export(root, save_dir, {"non_default_key": concrete})
     served = self._import_and_infer(
         save_dir, {"z": 1.}, signature_key="non_default_key")
     self.assertEqual({"out": 2.}, served)
Пример #16
0
 def test_optimizer(self):
     """The exported model reflects variable values as of export time.

     The model is called once before exporting and once after. The test
     expects the two losses to differ (presumably each call applies an
     optimizer step; `_ModelWithOptimizer` is defined outside this view).
     It also expects inference on the imported model to match the second
     loss.
     """
     x = constant_op.constant([[3., 4.]])
     y = constant_op.constant([2.])
     model = _ModelWithOptimizer()
     first_loss = model(x, y)
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     export.export(model, export_dir, model.call)
     second_loss = model(x, y)
     # Compare numerically. With TF 1.x tensor semantics `Tensor.__eq__`
     # compares object identity, so assertNotEqual on two distinct tensors
     # (or dicts holding them) passes regardless of their values.
     self.assertNotAllClose(first_loss, second_loss)
     self.assertAllClose(
         second_loss,
         self._import_and_infer(export_dir, {
             "x": [[3., 4.]],
             "y": [2.]
         }))
Пример #17
0
 def test_method_export_concrete(self):
   """A concrete function may be exported under a non-default signature key."""
   holder = tracking.Checkpointable()
   holder.f = def_function.function(lambda z: {"out": 2. * z})
   holder.f(constant_op.constant(1.))
   target_dir = os.path.join(self.get_temp_dir(), "saved_model")
   float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
   signatures = {"non_default_key": holder.f.get_concrete_function(float_spec)}
   export.export(holder, target_dir, signatures)
   outputs = self._import_and_infer(
       target_dir, {"z": 1.}, signature_key="non_default_key")
   self.assertEqual({"out": 2.}, outputs)
Пример #18
0
 def test_trivial_export_exception(self):
     """Exporting an object with nothing attached raises a ValueError.

     The error message is expected to mention "signature".
     """
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     empty_root = tracking.Checkpointable()
     with self.assertRaisesRegexp(ValueError, "signature"):
         export.export(empty_root, save_dir)