Beispiel #1
0
  def test_callable(self):
    """Checks which imported objects are callable after a save/load cycle.

    `m1` defines `__call__` on its class, while `m2` only has a `__call__`
    tf.function assigned on the instance. Both imported counterparts must be
    callable. The imported root, which has no `__call__`, must not be.
    """
    class M1(tracking.Checkpointable):

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
      def __call__(self, x):
        return x

    original = tracking.Checkpointable()
    original.m1 = M1()
    original.m2 = tracking.Checkpointable()
    triple = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
            lambda x: x * 3.0)
    original.m2.__call__ = triple
    loaded = self.cycle(original)
    one = constant_op.constant(1.0)

    self.assertTrue(callable(loaded.m1))
    self.assertAllEqual(original.m1(one), loaded.m1(one))

    # Python looks special methods up on the type, so an instance-level
    # `__call__` does not make `original.m2` callable. The imported
    # counterpart is callable nonetheless.
    self.assertTrue(callable(loaded.m2))
    self.assertAllEqual(original.m2.__call__(one), loaded.m2(one))

    # The root never received a `__call__` attribute, so it is not callable.
    self.assertFalse(callable(loaded))
 def testListWrapperBasic(self):
   """_ListWrapper compares exactly like the built-in list type.

   Unlike `data_structures.List`, `_ListWrapper` is used to automatically
   replace plain lists. It must therefore behave like `list` under ==, !=,
   <, <=, >, >= and +, and be an instance of `list`.
   """
   wrap = data_structures._ListWrapper
   x = tracking.Checkpointable()
   y = tracking.Checkpointable()
   # Equality.
   self.assertEqual([x, x], [x, x])
   self.assertEqual(wrap([x, x]), wrap([x, x]))
   self.assertEqual([x, x], wrap([x, x]))
   self.assertEqual(wrap([x, x]), [x, x])
   # Inequality.
   self.assertNotEqual([x, x], [y, x])
   self.assertNotEqual(wrap([x, x]), wrap([y, x]))
   self.assertNotEqual([x, x], wrap([y, x]))
   # Ordering.
   self.assertLess([x], [x, y])
   self.assertLess(wrap([x]), wrap([x, y]))
   self.assertLessEqual([x], [x, y])
   self.assertLessEqual(wrap([x]), wrap([x, y]))
   self.assertGreater([x, y], [x])
   self.assertGreater(wrap([x, y]), wrap([x]))
   self.assertGreaterEqual([x, y], [x])
   self.assertGreaterEqual(wrap([x, y]), wrap([x]))
   # Mixed comparison, conversion, concatenation and type.
   self.assertEqual([x], wrap([x]))
   self.assertEqual([x], list(data_structures.List([x])))
   self.assertEqual([x, x], wrap([x]) + [x])
   self.assertEqual([x, x], [x] + wrap([x]))
   self.assertIsInstance(wrap([x]), list)
Beispiel #3
0
 def test_structure_import(self):
   """Object identity within the dependency graph survives a cycle."""
   original = tracking.Checkpointable()
   original.dep_one = tracking.Checkpointable()
   original.dep_two = tracking.Checkpointable()
   original.dep_two.dep = tracking.Checkpointable()
   # dep_three aliases dep_two.dep, so both paths reach one object.
   original.dep_three = original.dep_two.dep
   restored = self.cycle(original)
   self.assertIs(restored.dep_three, restored.dep_two.dep)
   self.assertIsNot(restored.dep_one, restored.dep_two)
Beispiel #4
0
 def testMutationDirtiesList(self):
   """Saving fails once an element of a tracked list has been displaced."""
   parent = tracking.Checkpointable()
   first_child = tracking.Checkpointable()
   parent.l = [first_child]
   newcomer = tracking.Checkpointable()
   # Inserting at the front shifts `first_child` from index 0 to index 1.
   parent.l.insert(0, newcomer)
   ckpt = util.Checkpoint(a=parent)
   with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
     ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
 def testNoDependency(self):
   """Attributes wrapped in NoDependency are stored but not tracked."""
   root = tracking.Checkpointable()
   tracked = tracking.Checkpointable()
   root.hasdep = tracked
   untracked = tracking.Checkpointable()
   root.nodep = tracking.NoDependency(untracked)
   # Only `hasdep` shows up as a checkpoint dependency.
   self.assertEqual(1, len(root._checkpoint_dependencies))
   self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
   # Both attributes hand back the plain, unwrapped objects.
   self.assertIs(root.hasdep, tracked)
   self.assertIs(root.nodep, untracked)
Beispiel #6
0
 def testOutOfBandEditDirtiesList(self):
   """Mutating a list through an outside reference makes saving fail."""
   owner = tracking.Checkpointable()
   member = tracking.Checkpointable()
   external_list = [member]
   owner.l = external_list
   extra = tracking.Checkpointable()
   # Append through the original list object, not through `owner.l`.
   external_list.append(extra)
   ckpt = util.Checkpoint(a=owner)
   with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"):
     ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
Beispiel #7
0
 def testShallowCopyCheckpointable(self):
   """copy.copy yields a new object that shares children with the original."""
   source = tracking.Checkpointable()
   child = tracking.Checkpointable()
   source.a = [[1.]]
   source.b = {"a": child}
   clone = copy.copy(source)
   self.assertIs(child, clone.b["a"])
   self.assertIsNot(source, clone)
   self.assertEqual([[1.]], clone.a)
   # The copy's own attributes must be part of its dependency graph.
   clone_deps = util.list_objects(clone)
   for dependency in (clone.a, clone.b, clone.b["a"]):
     self.assertIn(dependency, clone_deps)
 def testListBasic(self):
   """Checkpointables in an assigned list are tracked through the list "l"."""
   holder = tracking.Checkpointable()
   first = tracking.Checkpointable()
   holder.l = [first]
   second = tracking.Checkpointable()
   holder.l.append(second)
   all_deps = util.list_objects(holder)
   self.assertIn(first, all_deps)
   self.assertIn(second, all_deps)
   # The list itself is the single direct dependency of `holder`.
   [list_dep] = holder._checkpoint_dependencies
   self.assertEqual("l", list_dep.name)
   self.assertIn(first, list_dep.ref)
   self.assertIn(second, list_dep.ref)
Beispiel #9
0
  def test_chain_callable(self):
    """A `__call__` chain ending in a tf.function makes the import callable."""
    triple = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
            lambda x: x * 3.0)
    root = tracking.Checkpointable()
    first_link = tracking.Checkpointable()
    root.__call__ = first_link
    second_link = tracking.Checkpointable()
    root.__call__.__call__ = second_link
    root.__call__.__call__.__call__ = triple

    imported = self.cycle(root)
    self.assertTrue(callable(imported))
    self.assertAllEqual(imported(constant_op.constant(1.0)).numpy(), 3.0)
Beispiel #10
0
 def test_structure_import(self):
   """Graph structure and a traced function both survive a cycle."""
   root = tracking.Checkpointable()
   doubler = def_function.function(
       lambda x: 2. * x,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
   root.f = doubler
   root.dep_one = tracking.Checkpointable()
   root.dep_two = tracking.Checkpointable()
   root.dep_two.dep = tracking.Checkpointable()
   # dep_three aliases dep_two.dep.
   root.dep_three = root.dep_two.dep
   restored = self.cycle(root)
   self.assertIs(restored.dep_three, restored.dep_two.dep)
   self.assertIsNot(restored.dep_one, restored.dep_two)
   self.assertEqual(4., restored.f(constant_op.constant(2.)).numpy())
 def testMultipleAssignment(self):
   """Re-tracking an already used name requires overwrite=True."""
   root = tracking.Checkpointable()
   root.leaf = tracking.Checkpointable()
   # Assigning the same object to the same name again does not raise.
   root.leaf = root.leaf
   replacement = tracking.Checkpointable()
   with self.assertRaises(ValueError):
     root._track_checkpointable(replacement, name="leaf")
   # Attribute assignment goes through an overridden __setattr__. Rejecting
   # it here would break backward compatibility, so it is allowed.
   root.leaf = replacement
   root._track_checkpointable(replacement, name="leaf", overwrite=True)
Beispiel #12
0
 def test_structure_import(self):
   """Structure survives a round trip through save.save and load.load."""
   root = tracking.Checkpointable()
   root.f = def_function.function(
       lambda x: 2. * x,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
   root.dep_one = tracking.Checkpointable()
   root.dep_two = tracking.Checkpointable()
   root.dep_two.dep = tracking.Checkpointable()
   # dep_three aliases dep_two.dep.
   root.dep_three = root.dep_two.dep
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, export_dir)
   restored = load.load(export_dir)
   self.assertIs(restored.dep_three, restored.dep_two.dep)
   self.assertIsNot(restored.dep_one, restored.dep_two)
Beispiel #13
0
 def testMultipleAssignment(self):
   """Only overwrite=True lets a tracked name be rebound to a new object."""
   root = tracking.Checkpointable()
   root.leaf = tracking.Checkpointable()
   root.leaf = root.leaf
   replacement = tracking.Checkpointable()
   with self.assertRaisesRegexp(ValueError, "already declared"):
     root._track_checkpointable(replacement, name="leaf")
   # Plain attribute assignment is still accepted: __setattr__ is overridden
   # and rejecting it would break backward compatibility.
   root.leaf = replacement
   root._track_checkpointable(replacement, name="leaf", overwrite=True)
   self.assertIs(replacement, root._lookup_dependency("leaf"))
   # Exactly one dependency remains, and it refers to the replacement.
   [(_, tracked_object)] = root._checkpoint_dependencies
   self.assertIs(replacement, tracked_object)
Beispiel #14
0
 def testDeepCopyCheckpointable(self):
   """copy.deepcopy duplicates children and tracks only the duplicates."""
   source = tracking.Checkpointable()
   child = tracking.Checkpointable()
   source.a = [[1.]]
   source.b = {"a": child}
   clone = copy.deepcopy(source)
   self.assertIsNot(source, clone)
   self.assertIsNot(child, clone.b["a"])
   self.assertEqual([[1.]], clone.a)
   self.assertIsInstance(clone.b["a"], tracking.Checkpointable)
   clone_deps = util.list_objects(clone)
   for dependency in (clone.a, clone.b, clone.b["a"]):
     self.assertIn(dependency, clone_deps)
   # The original child must not leak into the copy's dependency graph.
   self.assertNotIn(child, clone_deps)
Beispiel #15
0
  def test_soft_matching(self):
    """Imported functions accept any input compatible with a saved trace.

    `func` has a single concrete function for int32 vectors of unknown
    length. After a save/load cycle it must reject a scalar input and accept
    vectors of any length.
    """

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
    def func(x):
      return 2 * x

    root = tracking.Checkpointable()
    root.f = func

    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
    self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())

    self.assertEqual(
        1, len(function_serialization.list_all_concrete_functions(root.f)))

    imported = self.cycle(root)

    # A constant of shape () matches no saved concrete function, so the call
    # itself must raise. Keep only the call inside the context manager. A
    # value assertion there would also raise AssertionError when the call
    # *succeeds* with a different result, so the test would pass for the
    # wrong reason.
    with self.assertRaisesRegexp(AssertionError,
                                 "Could not find matching function to call"):
      imported.f(constant_op.constant(2))

    # TODO(vbardiovsky): When classes are revived with input_signatures, we
    # should also check that the calls below are not generating any more
    # concrete functions.
    self.assertAllEqual([2, 4, 6, 8],
                        imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
    self.assertAllEqual([2, 4, 6],
                        imported.f(constant_op.constant([1, 2, 3])).numpy())
Beispiel #16
0
  def test_structured_inputs(self):
    """Nested inputs dispatch only to concrete functions traced before saving."""

    def func(x, training=True):
      # Only the tensors inside the second element of `x` are used.
      _, (a, b) = x
      if training:
        return 2 * a["a"] + b
      else:
        return 7

    root = tracking.Checkpointable()
    root.f = def_function.function(func)

    ten = constant_op.constant(10)
    eleven = constant_op.constant(11)

    traced_input = [6, ({"a": ten}, eleven)]
    # The leading Python value differs, so this does not match the trace.
    incompatible_input = [7, ({"a": ten}, eleven)]
    # Same structure and Python value; only the tensors are swapped.
    compatible_input = [6, ({"a": eleven}, ten)]

    # Only `traced_input` is used before serialization, so the loaded model
    # can only serve inputs whose signature matches it.
    self.assertEqual(31, root.f(traced_input).numpy())

    imported = self.cycle(root)

    with self.assertRaisesRegexp(AssertionError,
                                 "Could not find matching function to call.*"):
      imported.f(incompatible_input)

    self.assertEqual(31, imported.f(traced_input).numpy())
    self.assertEqual(32, imported.f(compatible_input).numpy())
Beispiel #17
0
    def testNames(self):
        """Names chosen by UniqueNameTracker are the names used on restore.

        Colliding names are deduplicated with suffixes. A plain Checkpointable
        must use those deduplicated attribute names to receive the values.
        """
        prefix = os.path.join(self.get_temp_dir(), "ckpt")

        first_x = resource_variable_ops.ResourceVariable(2.)
        second_x = resource_variable_ops.ResourceVariable(3.)
        explicit_x_1 = resource_variable_ops.ResourceVariable(4.)
        y_var = resource_variable_ops.ResourceVariable(5.)
        tracker = containers.UniqueNameTracker()
        tracker.track(first_x, "x")
        # Collides with "x" and is restored below as "x_1" ...
        tracker.track(second_x, "x")
        # ... so the explicitly requested "x_1" ends up as "x_1_1".
        tracker.track(explicit_x_1, "x_1")
        tracker.track(y_var, "y")
        self.evaluate((first_x.initializer, second_x.initializer,
                       explicit_x_1.initializer, y_var.initializer))
        save_root = util.Checkpoint(slots=tracker)
        save_path = save_root.save(prefix)

        restored_slots = tracking.Checkpointable()
        restore_root = util.Checkpoint(slots=restored_slots)
        status = restore_root.restore(save_path)
        # The variables are attached after restore() was called; the pending
        # restore still fills them in.
        for attr_name in ("x", "x_1", "x_1_1", "y"):
            setattr(restored_slots, attr_name,
                    resource_variable_ops.ResourceVariable(0.))
        status.assert_consumed().run_restore_ops()
        expected_values = (("x", 2.), ("x_1", 3.), ("x_1_1", 4.), ("y", 5.))
        for attr_name, expected_value in expected_values:
            restored = getattr(restored_slots, attr_name)
            self.assertEqual(expected_value, self.evaluate(restored))
Beispiel #18
0
  def test_structured_output(self):
    """Nested outputs (namedtuple, list, dict) keep their structure on import."""

    # The namedtuple's fields are deliberately not in alphabetical order.
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
      total = input1 + input2
      product = input1 * input2
      return [named_tuple_type(a=total, b=product), input2, {"x": 0.5}]

    def assert_result(result, expected_a, expected_b, expected_second):
      # Checks values and that the namedtuple's field order is preserved.
      self.assertEqual(expected_a, result[0].a.numpy())
      self.assertEqual(expected_b, result[0].b.numpy())
      self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
      self.assertEqual(expected_second, result[1].numpy())
      self.assertEqual(0.5, result[2]["x"].numpy())

    root = tracking.Checkpointable()
    root.f = def_function.function(func)

    assert_result(root.f(constant_op.constant(2), constant_op.constant(3)),
                  5, 6, 3)

    imported = self.cycle(root)

    assert_result(imported.f(constant_op.constant(2), constant_op.constant(5)),
                  7, 10, 5)
Beispiel #19
0
 def test_variables(self):
   """Variable values survive a save/load cycle."""
   root = tracking.Checkpointable()
   root.v1 = variables.Variable(1.)
   root.v2 = variables.Variable(2.)
   imported = self.cycle(root)
   self.assertEqual(imported.v1.numpy(), 1.0)
   self.assertEqual(imported.v2.numpy(), 2.0)
Beispiel #20
0
 def test_single_function_default_signature(self):
   """A saved zero-argument tf.function serves its value as "output_0"."""
   model = tracking.Checkpointable()
   model.f = def_function.function(lambda: 3., input_signature=())
   model.f()
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(model, export_dir)
   outputs = _import_and_infer(export_dir, {})
   self.assertAllClose({"output_0": 3.}, outputs)
Beispiel #21
0
 def test_nested_outputs(self):
   """export.export rejects a concrete function with nested outputs."""
   root = tracking.Checkpointable()
   root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
   root.f(constant_op.constant(1.))
   concrete = root.f.get_concrete_function(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(ValueError, "non-flat outputs"):
     export.export(root, export_dir, concrete)
Beispiel #22
0
 def test_non_concrete_error(self):
   """export.export rejects a tf.function that has no input_signature."""
   root = tracking.Checkpointable()
   root.f = def_function.function(lambda x: 2. * x)
   root.f(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   expected_message = "must be converted to concrete functions"
   with self.assertRaisesRegexp(ValueError, expected_message):
     export.export(root, export_dir, root.f)
Beispiel #23
0
 def test_dedup_assets(self):
   """Two TrackableAssets over the same file import with the same path."""
   vocab_file = self._make_asset("contents")
   root = tracking.Checkpointable()
   root.asset1 = tracking.TrackableAsset(vocab_file)
   root.asset2 = tracking.TrackableAsset(vocab_file)
   imported = self.cycle(root)
   first_path = imported.asset1.asset_path.numpy()
   second_path = imported.asset2.asset_path.numpy()
   self.assertEqual(first_path, second_path)
 def _get_checkpoint_name(self, name):
   """Returns the serialized checkpoint name of a variable called `name`.

   Also checks that the name, once prefixed with "root/", is accepted by
   `ops.name_scope`.
   """
   root = tracking.Checkpointable()
   checkpointable_utils.add_variable(
       root, name=name, shape=[1, 2], dtype=dtypes.float64)
   named_variables, _, _ = checkpointable_utils._serialize_object_graph(
       root, saveables_cache=None)
   # Exactly one variable was added, so exactly one name is serialized.
   (only_variable,) = named_variables
   checkpoint_name = only_variable.name
   with ops.name_scope("root/" + checkpoint_name):
     pass  # The name must be usable as an op name once prefixed.
   return checkpoint_name
Beispiel #25
0
 def test_dict(self):
   """Dict attributes round-trip; the non-variable entry is not restored."""
   root = tracking.Checkpointable()
   root.variables = {"a": variables.Variable(1.)}
   root.variables["b"] = variables.Variable(2.)
   # A plain Python int; it does not come back after the cycle.
   root.variables["c"] = 1
   imported = self.cycle(root)
   self.assertEqual(1., imported.variables["a"].numpy())
   self.assertEqual(2., imported.variables["b"].numpy())
   self.assertEqual({"a", "b"}, set(imported.variables.keys()))
Beispiel #26
0
 def testDictWrapperBadKeys(self):
   """Saving fails when a tracked dict keys a checkpointable by an int."""
   holder = tracking.Checkpointable()
   holder.d = {}
   holder.d[1] = data_structures.List()
   model = training.Model()
   model.sub = holder
   weights_path = os.path.join(self.get_temp_dir(), "ckpt")
   with self.assertRaisesRegexp(ValueError, "non-string key"):
     model.save_weights(weights_path)
Beispiel #27
0
 def test_variables(self):
   """Variable values and `trainable` flags survive a save/load cycle."""
   root = tracking.Checkpointable()
   root.v1 = variables.Variable(1., trainable=True)
   root.v2 = variables.Variable(2., trainable=False)
   imported = self.cycle(root)
   self.assertEqual(imported.v1.numpy(), 1.0)
   self.assertTrue(imported.v1.trainable)
   self.assertEqual(imported.v2.numpy(), 2.0)
   self.assertFalse(imported.v2.trainable)
Beispiel #28
0
 def test_method_export_signature(self):
   """An exported tf.function with an input_signature can serve inference."""
   root = tracking.Checkpointable()
   root.f = def_function.function(
       lambda x: 2. * x,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
   root.f(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   export.export(root, export_dir, root.f)
   outputs = self._import_and_infer(export_dir, {"x": 1.})
   self.assertEqual({"output_0": 2.}, outputs)
Beispiel #29
0
 def testDictWrapperNoDependency(self):
   """A NoDependency dict may use int keys because it is not tracked."""
   holder = tracking.Checkpointable()
   holder.d = data_structures.NoDependency({})
   holder.d[1] = [3]
   # The untracked dict adds nothing to the object graph.
   self.assertEqual([holder], util.list_objects(holder))
   model = training.Model()
   model.sub = holder
   weights_path = os.path.join(self.get_temp_dir(), "ckpt")
   # Saving and loading both succeed despite the integer key.
   model.save_weights(weights_path)
   model.load_weights(weights_path)
Beispiel #30
0
 def test_capture_variables(self):
   """Imported functions read the current value of captured variables."""
   module = tracking.Checkpointable()
   module.weights = variables.Variable(2.)
   module.f = def_function.function(
       lambda x: module.weights * x,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
   imported = self.cycle(module)
   self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
   # Reassigning the restored variable changes the restored function's
   # result, so the variable is captured rather than frozen into a constant.
   imported.weights.assign(4.0)
   self.assertEqual(8., imported.f(constant_op.constant(2.)).numpy())