Esempio n. 1
0
    def test_explicit_input_signature(self, cycles):
        """A function with an explicit float32 input_signature is callable after `self.cycle`."""
        float_spec = tensor_spec.TensorSpec(None, dtypes.float32)

        @def_function.function(input_signature=[float_spec])
        def func(x):
            return 2 * x

        root = tracking.AutoCheckpointable()
        root.f = func

        loaded = self.cycle(root, cycles)
        result = loaded.f(constant_op.constant(2.0))
        self.assertEqual(4., result.numpy())
Esempio n. 2
0
    def test_unused_asset(self):
        """Saving and serving still work when the root tracks an asset no function uses."""
        root = tracking.AutoCheckpointable()
        signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
        root.f = def_function.function(lambda x: 2. * x, input_signature=signature)
        # The asset is tracked by `root` but never referenced by `root.f`.
        root.asset = tracking.TrackableAsset(self._vocab_path)

        export_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, export_dir)
        outputs = _import_and_infer(export_dir, {"x": [0.1]})
        self.assertAllClose({"output_0": [0.2]}, outputs)
Esempio n. 3
0
 def test_capture_assets(self, cycles):
     """A function returning a captured asset path yields a new, readable path after cycling."""
     root = tracking.AutoCheckpointable()
     asset_file = self._make_asset("contents")
     root.vocab = tracking.TrackableAsset(asset_file)
     root.f = def_function.function(
         lambda: root.vocab.asset_path, input_signature=[])
     imported = self.cycle(root, cycles)
     before = root.f().numpy()
     after = imported.f().numpy()
     # The restored function must point somewhere else, yet at the same content.
     self.assertNotEqual(before, after)
     with open(after, "r") as handle:
         self.assertEqual("contents", handle.read())
Esempio n. 4
0
 def test_method_save_signature(self):
   """A tf.function passed to `save.save` as its third argument is servable.

   Elsewhere in this file that argument is passed as `signatures=`.
   """
   root = tracking.AutoCheckpointable()
   spec = tensor_spec.TensorSpec(None, dtypes.float32)
   root.f = def_function.function(lambda x: 2. * x, input_signature=[spec])
   root.f(constant_op.constant(1.))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, save_dir, root.f)
   inferred = _import_and_infer(save_dir, {"x": 1.})
   self.assertEqual({"output_0": 2.}, inferred)
Esempio n. 5
0
 def test_overwritten_signatures_error(self, cycles):
     """Re-saving a loaded object whose `signatures` mapping was replaced raises ValueError."""
     exported = tracking.AutoCheckpointable()
     exported.f = def_function.function(lambda: constant_op.constant(1.))
     concrete = exported.f.get_concrete_function()
     imported = self.cycle(exported, cycles, signatures={"key": concrete})
     output = imported.signatures["key"]()["output_0"]
     self.assertEqual(1., output.numpy())
     # Swap in a hand-built mapping under a different key.
     imported.signatures = {"key1": imported.signatures["key"]}
     with self.assertRaisesRegexp(ValueError, "signatures"):
         save.save(imported, tempfile.mkdtemp(prefix=self.get_temp_dir()))
Esempio n. 6
0
 def test_variable(self):
     """Variables captured by an exported concrete function keep their values when served."""
     root = tracking.AutoCheckpointable()
     root.v1 = variables.Variable(3.)
     root.v2 = variables.Variable(2.)
     root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
     root.f(constant_op.constant(1.))
     signature = root.f.get_concrete_function(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(root, save_dir, signature)
     # 3. * 2. * 2. == 12.
     served = _import_and_infer(save_dir, {"x": 2.})
     self.assertAllEqual({"output_0": 12.}, served)
Esempio n. 7
0
  def test_concrete_function_no_signature(self):
    """A concrete function traced from an example int tensor is callable after cycling."""

    @def_function.function
    def func(x):
      return 2 * x

    root = tracking.AutoCheckpointable()
    example = constant_op.constant([1])
    root.f = func.get_concrete_function(example)
    self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
    loaded = self.cycle(root)
    result = loaded.f(constant_op.constant([3]))
    self.assertAllEqual([6], result.numpy())
Esempio n. 8
0
    def test_concrete_function_no_signature(self, cycles):
        """A concrete function saved without signatures is callable after `cycles` cycles."""
        @def_function.function
        def func(x):
            return 2 * x

        root = tracking.AutoCheckpointable()
        example = constant_op.constant([1])
        root.f = func.get_concrete_function(example)
        self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
        # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
        loaded = self.cycle(root, cycles, signatures={})
        result = loaded.f(constant_op.constant([3]))
        self.assertAllEqual([6], result.numpy())
Esempio n. 9
0
    def test_perserve_argspec(self, cycles):
        """The restored function reports the same full argspec as the Python function."""
        # NOTE: "perserve" is a typo for "preserve"; the name is kept unchanged so
        # existing test selection by name keeps working.
        def f(a, b, c):  # pylint: disable=unused-argument
            return None

        expected = tf_inspect.getfullargspec(f)

        root = tracking.AutoCheckpointable()
        root.f = def_function.function(f)
        loaded = self.cycle(root, cycles)

        self.assertEqual(expected, tf_inspect.getfullargspec(loaded.f))
Esempio n. 10
0
 def testAssertions(self):
     """TestCase assertion helpers accept tracked dict/list wrappers like plain containers."""

     def make_zeros():
         # A fresh nested structure per call, so no assertion shares an object.
         return {"k": [numpy.zeros([2, 2])]}

     obj = tracking.AutoCheckpointable()
     obj.l = make_zeros()
     self.assertAllEqual(nest.flatten(make_zeros()), nest.flatten(obj.l))
     self.assertAllClose(make_zeros(), obj.l)
     nest.map_structure(self.assertAllClose, obj.l, make_zeros())
     obj.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
     self.assertAllClose(
         {"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]},
         self.evaluate(obj.tensors))
Esempio n. 11
0
 def _get_checkpoint_name(self, name):
     """Return the serialized name of a lone float64 variable called `name`.

     The variable is added to a fresh root, and the object graph is serialized
     with `_serialize_object_graph`. The `.name` of the single entry in the
     first element of the result is returned; presumably this is the
     variable's checkpoint key.
     """
     root = tracking.AutoCheckpointable()
     checkpointable_utils.add_variable(
         root, name=name, shape=[1, 2], dtype=dtypes.float64)
     serialized = checkpointable_utils._serialize_object_graph(
         root, saveables_cache=None)
     (named_variable,), _, _ = serialized
     checkpoint_name = named_variable.name
     with ops.name_scope("root/" + checkpoint_name):
         pass  # The name must also be usable as an op name once prefixed.
     return checkpoint_name
Esempio n. 12
0
  def test_explicit_save_signature(self):
    """A concrete function given in a signature mapping is callable after cycling.

    The mapping is passed as the second positional argument of `self.cycle`.
    """

    @def_function.function
    def func(x):
      return 2 * x

    root = tracking.AutoCheckpointable()
    root.f = func

    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    signatures = {"f": root.f.get_concrete_function(float_spec)}
    loaded = self.cycle(root, signatures)
    self.assertEqual(4., loaded.f(constant_op.constant(2.0)).numpy())
Esempio n. 13
0
    def test_nested_functions(self):
        """A tf.function that calls another tf.function computes correctly after cycling.

        `g(x) = f(x) + 1` with `f(x) = 2 * x`, so `g([1.0])` must be `[3.0]`.
        """
        f = def_function.function(
            lambda x: x * 2.0,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        g = def_function.function(
            lambda x: f(x) + 1.0,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

        root = tracking.AutoCheckpointable()
        root.g = g
        imported = self.cycle(root)
        # Check the value, not just that the call succeeds: a restored g that
        # lost or mis-wired its call into f would otherwise go unnoticed.
        self.assertAllEqual([3.0], imported.g(constant_op.constant([1.0])).numpy())
Esempio n. 14
0
 def test_nested_inputs(self):
   """A list-valued input signature leads to duplicate input names and is rejected."""
   root = tracking.AutoCheckpointable()
   pair_spec = [tensor_spec.TensorSpec(None, dtypes.float32),
                tensor_spec.TensorSpec(None, dtypes.float32)]
   root.f = def_function.function(
       lambda x: 2. * x[0], input_signature=(pair_spec,))
   root.f([constant_op.constant(1.), constant_op.constant(1.)])
   # Concrete functions must always have uniquely named Tensor inputs because
   # save relies on it; both list elements here would be named 'x'.
   with self.assertRaisesRegexp(ValueError, "two arguments named 'x'"):
     root.f.get_concrete_function()
Esempio n. 15
0
 def testNestedLists(self):
     """Trackables in nested lists are dependencies; replacing a list element blocks saving."""
     outer = tracking.AutoCheckpointable()
     outer.l = []
     first = tracking.AutoCheckpointable()
     outer.l.append([first])
     second = tracking.AutoCheckpointable()
     outer.l[0].append(second)
     deps = util.list_objects(outer)
     self.assertIn(first, deps)
     self.assertIn(second, deps)

     # A plain int may sit in the list but is not reported as a dependency.
     outer.l[0].append(1)
     third = tracking.AutoCheckpointable()
     outer.l[0].append(third)
     deps = util.list_objects(outer)
     self.assertIn(third, deps)
     self.assertIn(first, deps)
     self.assertIn(second, deps)
     self.assertNotIn(1, deps)

     # Lists that are nested at assignment time (including an inner list filled
     # afterwards) are tracked as well.
     fourth = tracking.AutoCheckpointable()
     fifth = tracking.AutoCheckpointable()
     outer.l1 = [[], [fourth]]
     outer.l1[0].append(fifth)
     deps = util.list_objects(outer)
     self.assertIn(fourth, deps)
     self.assertIn(fifth, deps)

     checkpoint = util.Checkpoint(a=outer)

     def save_checkpoint():
         return checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))

     save_checkpoint()
     # Appending and mutating a NoDependency-wrapped list still saves fine.
     outer.l[0].append(data_structures.NoDependency([]))
     outer.l[0][-1].append(5)
     save_checkpoint()
     # Dirtying the inner list means the root object is unsaveable.
     outer.l[0][1] = 2
     with self.assertRaisesRegexp(ValueError,
                                  "A list element was replaced"):
         save_checkpoint()
 def testHashing(self):
   """Fresh Mapping objects are distinct set members; tracked dicts are unhashable."""
   mappings = {data_structures.Mapping(), data_structures.Mapping()}
   self.assertEqual(2, len(mappings))
   self.assertNotIn(data_structures.Mapping(), mappings)
   # In contrast to Mapping, dict wrappers are not hashable
   holder = tracking.AutoCheckpointable()
   holder.d = {}
   self.assertEqual({}, holder.d)
   self.assertFalse({} != holder.d)  # pylint: disable=g-explicit-bool-comparison
   self.assertNotEqual({1: 2}, holder.d)
   with self.assertRaisesRegexp(TypeError, "unhashable"):
     set([holder.d])
 def testLayerCollectionWithExternalMutation(self):
     """Layers put into the original dict after assignment appear in `wrapper.layers`."""
     backing = {}
     root = tracking.AutoCheckpointable()
     root.wrapper = backing
     self.assertEqual([], root.wrapper.layers)
     self.assertEqual([], root.wrapper.trainable_weights)
     first_layer = core.Dense(1)
     second_layer = core.Dense(1)
     # Mutate the plain dict directly rather than going through root.wrapper.
     backing["a"] = first_layer
     backing["b"] = second_layer
     self.assertEqual([first_layer, second_layer], root.wrapper.layers)
     # The layers have still not created variables
     self.assertEqual([], root.wrapper.trainable_weights)
Esempio n. 18
0
 def test_asset_path_returned(self):
     """After moving the SavedModel directory, the served asset path lies under the new one."""
     root = tracking.AutoCheckpointable()
     root.path = tracking.TrackableAsset(self._vocab_path)
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     root.get_asset = def_function.function(lambda: root.path.asset_path)
     signature = root.get_asset.get_concrete_function()
     save.save(root, save_dir, signatures=signature)
     moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
     file_io.rename(save_dir, moved_dir)
     served_path = _import_and_infer(moved_dir, {})["output_0"]
     self.assertIn(compat.as_str_any(moved_dir),
                   compat.as_str_any(served_path))
Esempio n. 19
0
  def test_load_in_func_graph(self, cycles):
    """A saved object can be loaded from inside a tf.function body and then called."""
    root = tracking.AutoCheckpointable()
    root.v1 = variables.Variable(1.)
    root.v2 = variables.Variable(2.)
    # The lambda looks up `root` by name, so it sees the rebinding below.
    root.f = def_function.function(
        lambda x: root.v2 * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    # Pre-cycle so the final save starts from an object already cycled
    # `cycles - 1` times.
    if cycles > 1:
      root = self.cycle(root, cycles - 1)
    export_path = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(root, export_path)

    holder = tracking.AutoCheckpointable()

    @def_function.function
    def func(x):
      # Load only on the first run of this Python body and cache it on holder.
      if not hasattr(holder, "model"):
        holder.model = load.load(export_path)
      return holder.model.f(x)

    self.assertEqual(4.0, func(constant_op.constant(2.)).numpy())
Esempio n. 20
0
 def test_method_save_concrete(self):
     """A concrete function saved under a custom signature key keeps its dict output names."""
     root = tracking.AutoCheckpointable()
     root.f = def_function.function(lambda z: {"out": 2. * z})
     root.f(constant_op.constant(1.))
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     concrete = root.f.get_concrete_function(
         tensor_spec.TensorSpec(None, dtypes.float32))
     save.save(root, save_dir, {"non_default_key": concrete})
     outputs = _import_and_infer(
         save_dir, {"z": 1.}, signature_key="non_default_key")
     self.assertEqual({"out": 2.}, outputs)
Esempio n. 21
0
 def test_dict(self, cycles):
   """Variables and functions held in tracked dicts are callable/readable after cycling."""
   root = tracking.AutoCheckpointable()
   root.variables = {"a": variables.Variable(1.)}
   root.variables["b"] = variables.Variable(2.)
   root.variables["c"] = 1
   root.funcs = {"a": def_function.function(lambda: constant_op.constant(100.))}
   root.funcs["conc"] = root.funcs["a"].get_concrete_function()
   restored = self.cycle(root, cycles)
   self.assertEqual(1., restored.variables["a"].numpy())
   self.assertEqual(2., restored.variables["b"].numpy())
   # The plain int stored under "c" is not among the restored keys.
   self.assertEqual({"a", "b"}, set(restored.variables.keys()))
   self.assertEqual(100., restored.funcs["a"]().numpy())
   self.assertEqual(100., restored.funcs["conc"]().numpy())
Esempio n. 22
0
    def test_nested_functions(self, cycles):
        """A tf.function that calls another tf.function computes correctly after cycling.

        `g(x) = f(x) + 1` with `f(x) = 2 * x`, so `g([1.0])` must be `[3.0]`.
        """
        f = def_function.function(
            lambda x: x * 2.0,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        g = def_function.function(
            lambda x: f(x) + 1.0,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

        root = tracking.AutoCheckpointable()
        root.g = g
        # TODO(vbardiovsky): Enable this test. For this to work, we must ensure that
        # concrete_function._inference_function._graph._functions contains all
        # functions that were on the graph before saving.
        # NOTE(review): `cycles` is deliberately unused until the TODO is fixed;
        # a single cycle is run for every parameterization.
        imported = self.cycle(root, 1)
        # Check the value, not just that the call succeeds.
        self.assertAllEqual([3.0], imported.g(constant_op.constant([1.0])).numpy())
Esempio n. 23
0
  def test_revived_concrete_function_tensorspec_kwargs(self, cycles):
    """Named TensorSpecs let a `*args` concrete function be called by keyword, before and after cycling."""

    @def_function.function
    def func(*args):
      x, y = args
      return x * (y + 1.)

    root = tracking.AutoCheckpointable()
    x_spec = tensor_spec.TensorSpec([], dtypes.float32, name="x")
    y_spec = tensor_spec.TensorSpec([], dtypes.float32, name="y")
    root.f = func.get_concrete_function(x_spec, y_spec)
    # 2. * (3. + 1.) == 8.
    direct = root.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
    self.assertEqual(8., direct.numpy())
    imported = self.cycle(root, cycles, signatures={})
    revived = imported.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
    self.assertEqual(8., revived.numpy())
Esempio n. 24
0
  def test_revived_concrete_function_kwargs(self, cycles):
    """A concrete function accepts keyword arguments by Python parameter name, before and after cycling."""

    @def_function.function
    def func(x, y):
      return x * (y + 1.)

    root = tracking.AutoCheckpointable()
    scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
    root.f = func.get_concrete_function(
        scalar_spec, tensor_spec.TensorSpec([], dtypes.float32))
    # 2. * (3. + 1.) == 8.
    direct = root.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
    self.assertEqual(8., direct.numpy())
    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    imported = self.cycle(root, cycles, signatures={})
    revived = imported.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
    self.assertEqual(8., revived.numpy())
Esempio n. 25
0
    def test_function_and_component(self, cycles):
        """A tf.function and one of its concrete functions both survive cycling."""
        @def_function.function
        def func(v):
            return v + 1

        root = tracking.AutoCheckpointable()
        root.func = func
        int_spec = tensor_spec.TensorSpec(None, dtypes.int32)
        root.concrete_func = func.get_concrete_function(int_spec)
        one = constant_op.constant(1)
        callables = ("func", "concrete_func")
        for attr in callables:
            self.assertEqual(2, getattr(root, attr)(one).numpy())
        imported = self.cycle(root, cycles)
        for attr in callables:
            self.assertEqual(2, getattr(imported, attr)(one).numpy())
 def testListWrapperBasic(self):
   """_ListWrapper compares like the built-in list type.

   Unlike List, _ListWrapper is used to automatically replace lists, so it must
   be interchangeable with plain lists in comparisons, concatenation and
   isinstance checks.
   """
   wrap = data_structures._ListWrapper
   a = tracking.AutoCheckpointable()
   b = tracking.AutoCheckpointable()
   # Equality / inequality across every list/wrapper pairing.
   self.assertEqual([a, a], [a, a])
   self.assertEqual(wrap([a, a]), wrap([a, a]))
   self.assertEqual([a, a], wrap([a, a]))
   self.assertEqual(wrap([a, a]), [a, a])
   self.assertNotEqual([a, a], [b, a])
   self.assertNotEqual(wrap([a, a]), wrap([b, a]))
   self.assertNotEqual([a, a], wrap([b, a]))
   # Ordering comparisons.
   self.assertLess([a], [a, b])
   self.assertLess(wrap([a]), wrap([a, b]))
   self.assertLessEqual([a], [a, b])
   self.assertLessEqual(wrap([a]), wrap([a, b]))
   self.assertGreater([a, b], [a])
   self.assertGreater(wrap([a, b]), wrap([a]))
   self.assertGreaterEqual([a, b], [a])
   self.assertGreaterEqual(wrap([a, b]), wrap([a]))
   # Conversion, concatenation and type.
   self.assertEqual([a], wrap([a]))
   self.assertEqual([a], list(data_structures.List([a])))
   self.assertEqual([a, a], wrap([a]) + [a])
   self.assertEqual([a, a], [a] + wrap([a]))
   self.assertIsInstance(wrap([a]), list)
Esempio n. 27
0
    def test_implicit_input_signature(self):
        """Both traces made before saving (float and int inputs) work after cycling."""
        @def_function.function
        def func(x):
            return 2 * x

        root = tracking.AutoCheckpointable()
        root.f = func

        # Add two traces: one from a float constant, one from an int constant.
        for value in (1., 1):
            root.f(constant_op.constant(value))

        imported = self.cycle(root)

        self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
        self.assertEqual(14, imported.f(constant_op.constant(7)).numpy())
    def testDictDeepCopy(self):
        """deepcopy of a tracked dict is independent, yet a dirty original yields a dirty copy."""
        root = tracking.AutoCheckpointable()
        orig_dict = {"a": [1.]}
        root.a = orig_dict
        duplicate = copy.deepcopy(root.a)
        self.assertAllEqual([1.], duplicate["a"])
        self.assertIsNot(root.a, duplicate)
        self.assertIsNot(root.a["a"], duplicate["a"])

        # Dirtiness should be inherited
        util.list_objects(root.a)
        # Mutate the original dict behind the wrapper; listing must now fail
        # for the wrapper and for any fresh deep copy of it.
        orig_dict["b"] = []
        with self.assertRaises(ValueError):
            util.list_objects(root.a)
        with self.assertRaises(ValueError):
            util.list_objects(copy.deepcopy(root.a))
    def testListDeepCopy(self):
        """deepcopy of a tracked list is independent, yet a dirty original yields a dirty copy."""
        root = tracking.AutoCheckpointable()
        orig_list = [[1.]]
        root.a = orig_list
        duplicate = copy.deepcopy(root.a)
        self.assertAllEqual([[1.]], duplicate)
        self.assertIsNot(root.a, duplicate)
        self.assertIsNot(root.a[0], duplicate[0])

        # Dirtiness should be inherited
        util.list_objects(root.a)
        # Mutate the original list behind the wrapper; listing must now fail
        # for the wrapper and for any fresh deep copy of it.
        orig_list.append(1.)
        with self.assertRaises(ValueError):
            util.list_objects(root.a)
        with self.assertRaises(ValueError):
            util.list_objects(copy.deepcopy(root.a))
Esempio n. 30
0
  def test_concrete_function_arg_names(self):
    """A restored concrete function accepts its input by the original argument name `x`."""

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
    def func(x):
      return 2 * x

    root = tracking.AutoCheckpointable()
    root.f = func.get_concrete_function()

    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())

    imported = self.cycle(root)

    doubled = imported.f(x=constant_op.constant([1, 2, 3]))
    self.assertAllEqual([2, 4, 6], doubled.numpy())