def test_callable(self):
    """Objects exposing a tf.function `__call__` are callable after a cycle."""

    class M1(tracking.Checkpointable):

        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        def __call__(self, x):
            return x

    root = tracking.Checkpointable()
    root.m1 = M1()
    root.m2 = tracking.Checkpointable()
    triple = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
            lambda x: x * 3.0)
    root.m2.__call__ = triple

    loaded = self.cycle(root)
    one = constant_op.constant(1.0)

    self.assertTrue(callable(loaded.m1))
    self.assertAllEqual(root.m1(one), loaded.m1(one))

    # `root.m2` itself is not callable: `__call__` was set on the instance,
    # not on the class. After the serialization cycle, the loaded object is.
    self.assertTrue(callable(loaded.m2))
    self.assertAllEqual(root.m2.__call__(one), loaded.m2(one))

    # User objects without a `__call__` attribute must stay non-callable.
    self.assertFalse(callable(loaded))
def testListWrapperBasic(self):
    """_ListWrapper compares, concatenates and type-checks like a plain list.

    Unlike `List`, `_ListWrapper` is used to automatically replace built-in
    lists, so it has to behave like the built-in `list` type.
    """
    wrap = data_structures._ListWrapper
    first = tracking.Checkpointable()
    second = tracking.Checkpointable()

    # Equality.
    self.assertEqual([first, first], [first, first])
    self.assertEqual(wrap([first, first]), wrap([first, first]))
    self.assertEqual([first, first], wrap([first, first]))
    self.assertEqual(wrap([first, first]), [first, first])

    # Inequality.
    self.assertNotEqual([first, first], [second, first])
    self.assertNotEqual(wrap([first, first]), wrap([second, first]))
    self.assertNotEqual([first, first], wrap([second, first]))

    # Ordering.
    self.assertLess([first], [first, second])
    self.assertLess(wrap([first]), wrap([first, second]))
    self.assertLessEqual([first], [first, second])
    self.assertLessEqual(wrap([first]), wrap([first, second]))
    self.assertGreater([first, second], [first])
    self.assertGreater(wrap([first, second]), wrap([first]))
    self.assertGreaterEqual([first, second], [first])
    self.assertGreaterEqual(wrap([first, second]), wrap([first]))

    # Conversion, concatenation and type.
    self.assertEqual([first], wrap([first]))
    self.assertEqual([first], list(data_structures.List([first])))
    self.assertEqual([first, first], wrap([first]) + [first])
    self.assertEqual([first, first], [first] + wrap([first]))
    self.assertIsInstance(wrap([first]), list)
def test_structure_import(self):
    """Shared sub-objects keep their identity across a save/load cycle."""
    root = tracking.Checkpointable()
    root.dep_one = tracking.Checkpointable()
    root.dep_two = tracking.Checkpointable()
    leaf = tracking.Checkpointable()
    root.dep_two.dep = leaf
    # `dep_three` aliases the very object reachable via `dep_two.dep`.
    root.dep_three = leaf

    loaded = self.cycle(root)

    self.assertIs(loaded.dep_three, loaded.dep_two.dep)
    self.assertIsNot(loaded.dep_one, loaded.dep_two)
def testMutationDirtiesList(self):
    """Inserting in front of a tracked list element makes saving fail."""
    owner = tracking.Checkpointable()
    original_item = tracking.Checkpointable()
    owner.l = [original_item]
    inserted_item = tracking.Checkpointable()
    # insert(0, ...) shifts the existing element to a new index.
    owner.l.insert(0, inserted_item)

    checkpoint = util.Checkpoint(a=owner)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
        checkpoint.save(prefix)
def testNoDependency(self):
    """Values wrapped in NoDependency are stored but not tracked."""
    root = tracking.Checkpointable()
    tracked = tracking.Checkpointable()
    root.hasdep = tracked
    untracked = tracking.Checkpointable()
    root.nodep = tracking.NoDependency(untracked)

    # Only `hasdep` is registered as a checkpoint dependency.
    dependencies = root._checkpoint_dependencies
    self.assertEqual(1, len(dependencies))
    self.assertIs(dependencies[0].ref, root.hasdep)

    # Both attributes read back as the raw objects (no wrapper visible).
    self.assertIs(root.hasdep, tracked)
    self.assertIs(root.nodep, untracked)
def testOutOfBandEditDirtiesList(self):
    """Mutating the original list object after assignment makes saving fail."""
    owner = tracking.Checkpointable()
    first_item = tracking.Checkpointable()
    external_list = [first_item]
    owner.l = external_list
    late_item = tracking.Checkpointable()
    # Edit the caller's list directly, not through `owner.l`.
    external_list.append(late_item)

    checkpoint = util.Checkpoint(a=owner)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"):
        checkpoint.save(prefix)
def testShallowCopyCheckpointable(self):
    """copy.copy yields a new object that shares sub-objects and tracks them."""
    source = tracking.Checkpointable()
    child = tracking.Checkpointable()
    source.a = [[1.]]
    source.b = {"a": child}

    clone = copy.copy(source)

    # Shallow copy: the nested child is shared, the outer object is not.
    self.assertIs(child, clone.b["a"])
    self.assertIsNot(source, clone)
    self.assertEqual([[1.]], clone.a)

    reachable = util.list_objects(clone)
    for expected in (clone.a, clone.b, clone.b["a"]):
        self.assertIn(expected, reachable)
def testListBasic(self):
    """Elements of a list attribute, including appended ones, are tracked."""
    holder = tracking.Checkpointable()
    first = tracking.Checkpointable()
    holder.l = [first]
    appended = tracking.Checkpointable()
    holder.l.append(appended)

    reachable = util.list_objects(holder)
    self.assertIn(first, reachable)
    self.assertIn(appended, reachable)

    # The list itself is the single direct dependency, named after the
    # attribute it was assigned to.
    (list_dependency,) = holder._checkpoint_dependencies
    self.assertEqual("l", list_dependency.name)
    self.assertIn(first, list_dependency.ref)
    self.assertIn(appended, list_dependency.ref)
def test_chain_callable(self):
    """A chain of `__call__` attributes ending in a tf.function is callable."""
    triple = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
            lambda x: x * 3.0)

    root = tracking.Checkpointable()
    middle = tracking.Checkpointable()
    root.__call__ = middle
    innermost = tracking.Checkpointable()
    middle.__call__ = innermost
    innermost.__call__ = triple

    loaded = self.cycle(root)
    self.assertTrue(callable(loaded))

    one = constant_op.constant(1.0)
    self.assertAllEqual(loaded(one).numpy(), 3.0)
def test_structure_import(self):
    """Object identity and an attached tf.function survive a cycle."""
    root = tracking.Checkpointable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.dep_one = tracking.Checkpointable()
    root.dep_two = tracking.Checkpointable()
    leaf = tracking.Checkpointable()
    root.dep_two.dep = leaf
    # `dep_three` aliases the very object reachable via `dep_two.dep`.
    root.dep_three = leaf

    loaded = self.cycle(root)

    self.assertIs(loaded.dep_three, loaded.dep_two.dep)
    self.assertIsNot(loaded.dep_one, loaded.dep_two)
    two = constant_op.constant(2.)
    self.assertEqual(4., loaded.f(two).numpy())
def testMultipleAssignment(self):
    """Duplicate dependency names are rejected unless overwrite=True is given.

    Plain attribute re-assignment is still allowed for backward
    compatibility. After the explicit overwrite, exactly one "leaf"
    dependency must remain, and it must point at the new object.
    """
    root = tracking.Checkpointable()
    root.leaf = tracking.Checkpointable()
    # Re-assigning the same object under the same name is accepted.
    root.leaf = root.leaf
    duplicate_name_dep = tracking.Checkpointable()
    with self.assertRaises(ValueError):
        root._track_checkpointable(duplicate_name_dep, name="leaf")
    # No error; we're overriding __setattr__, so we can't really stop people
    # from doing this while maintaining backward compatibility.
    root.leaf = duplicate_name_dep
    root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
    # Verify the overwrite replaced the original dependency instead of
    # leaving a stale entry or adding a second one.
    (dependency,) = root._checkpoint_dependencies
    self.assertEqual("leaf", dependency.name)
    self.assertIs(duplicate_name_dep, dependency.ref)
def test_structure_import(self):
    """Object identity survives a save.save / load.load round trip."""
    root = tracking.Checkpointable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.dep_one = tracking.Checkpointable()
    root.dep_two = tracking.Checkpointable()
    leaf = tracking.Checkpointable()
    root.dep_two.dep = leaf
    # `dep_three` aliases the very object reachable via `dep_two.dep`.
    root.dep_three = leaf

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, export_dir)
    loaded = load.load(export_dir)

    self.assertIs(loaded.dep_three, loaded.dep_two.dep)
    self.assertIsNot(loaded.dep_one, loaded.dep_two)
def testMultipleAssignment(self):
    """Re-tracking a name needs overwrite=True; attribute re-assignment is allowed."""
    root = tracking.Checkpointable()
    root.leaf = tracking.Checkpointable()
    # Re-assigning the same object under the same name is accepted.
    existing = root.leaf
    root.leaf = existing

    replacement = tracking.Checkpointable()
    with self.assertRaisesRegexp(ValueError, "already declared"):
        root._track_checkpointable(replacement, name="leaf")

    # Attribute assignment goes through an overridden __setattr__, which has
    # to keep accepting this for backward compatibility.
    root.leaf = replacement
    root._track_checkpointable(replacement, name="leaf", overwrite=True)

    self.assertIs(replacement, root._lookup_dependency("leaf"))
    (_, tracked_object), = root._checkpoint_dependencies
    self.assertIs(replacement, tracked_object)
def testDeepCopyCheckpointable(self):
    """copy.deepcopy yields fresh, tracked sub-objects."""
    source = tracking.Checkpointable()
    child = tracking.Checkpointable()
    source.a = [[1.]]
    source.b = {"a": child}

    clone = copy.deepcopy(source)

    self.assertIsNot(source, clone)
    self.assertIsNot(child, clone.b["a"])
    self.assertEqual([[1.]], clone.a)
    self.assertIsInstance(clone.b["a"], tracking.Checkpointable)

    reachable = util.list_objects(clone)
    for expected in (clone.a, clone.b, clone.b["a"]):
        self.assertIn(expected, reachable)
    # The copy must not reach back into the original object graph.
    self.assertNotIn(child, reachable)
def test_soft_matching(self):
    """A function with a [None]-shaped signature accepts any length-1D input.

    The loaded function must reject a scalar, because it does not match the
    rank-1 input signature.
    """
    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
    def func(x):
        return 2 * x

    root = tracking.Checkpointable()
    root.f = func

    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
    self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())
    self.assertEqual(
        1, len(function_serialization.list_all_concrete_functions(root.f)))

    imported = self.cycle(root)

    # We cannot call the function with a constant of shape (). The call must
    # sit directly inside assertRaises. Wrapping it in an assertEqual would
    # let the test pass even if the call succeeded, because the value
    # mismatch itself raises AssertionError.
    with self.assertRaises(AssertionError):
        imported.f(constant_op.constant(2))

    # TODO(vbardiovsky): When classes are revived with input_signatures, we
    # should also check that the calls below are not generating any more
    # concrete functions.
    self.assertAllEqual([2, 4, 6, 8],
                        imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
    self.assertAllEqual([2, 4, 6],
                        imported.f(constant_op.constant([1, 2, 3])).numpy())
def test_structured_inputs(self):
    """Loaded functions accept only inputs whose structure was traced."""

    def func(x, training=True):
        # x is a nested structure, we care about one particular tensor.
        _, (a, b) = x
        if training:
            return 2 * a["a"] + b
        else:
            return 7

    root = tracking.Checkpointable()
    root.f = def_function.function(func)

    ten = constant_op.constant(10)
    eleven = constant_op.constant(11)
    traced_input = [6, ({"a": ten}, eleven)]
    # Differs from traced_input in the leading Python int: not compatible.
    mismatched_input = [7, ({"a": ten}, eleven)]
    # Same leading int and structure, tensors swapped: compatible.
    swapped_input = [6, ({"a": eleven}, ten)]

    # Only f(traced_input) runs before serialization, so only inputs with a
    # matching signature are valid on the loaded object.
    self.assertEqual(31, root.f(traced_input).numpy())

    loaded = self.cycle(root)

    with self.assertRaisesRegexp(AssertionError,
                                 "Could not find matching function to call.*"):
        loaded.f(mismatched_input)
    self.assertEqual(31, loaded.f(traced_input).numpy())
    self.assertEqual(32, loaded.f(swapped_input).numpy())
def testNames(self):
    """UniqueNameTracker names restore onto matching plain attributes.

    Tracking "x", "x", "x_1" yields names x, x_1 and x_1_1. The restore side
    creates variables under those attribute names after restore() is called.
    """
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    first_x = resource_variable_ops.ResourceVariable(2.)
    second_x = resource_variable_ops.ResourceVariable(3.)
    explicit_x_1 = resource_variable_ops.ResourceVariable(4.)
    y_var = resource_variable_ops.ResourceVariable(5.)

    tracker = containers.UniqueNameTracker()
    tracker.track(first_x, "x")
    tracker.track(second_x, "x")
    tracker.track(explicit_x_1, "x_1")
    tracker.track(y_var, "y")
    self.evaluate((first_x.initializer, second_x.initializer,
                   explicit_x_1.initializer, y_var.initializer))

    saving_checkpoint = util.Checkpoint(slots=tracker)
    saved_path = saving_checkpoint.save(prefix)

    target = tracking.Checkpointable()
    restoring_checkpoint = util.Checkpoint(slots=target)
    status = restoring_checkpoint.restore(saved_path)
    # Variables are created after restore(), so their values are filled in
    # by the deferred restore.
    target.x = resource_variable_ops.ResourceVariable(0.)
    target.x_1 = resource_variable_ops.ResourceVariable(0.)
    target.x_1_1 = resource_variable_ops.ResourceVariable(0.)
    target.y = resource_variable_ops.ResourceVariable(0.)
    status.assert_consumed().run_restore_ops()

    for attribute, expected in (("x", 2.), ("x_1", 3.), ("x_1_1", 4.),
                                ("y", 5.)):
        self.assertEqual(expected, self.evaluate(getattr(target, attribute)))
def test_structured_output(self):
    """Nested outputs (namedtuple, tensor, dict) keep their structure on load."""
    # Use fields with non-alphabetical order
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
        named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
        return [named_tuple, input2, {"x": 0.5}]

    def check_outputs(outputs, expected_a, expected_b, expected_second):
        self.assertEqual(expected_a, outputs[0].a.numpy())
        self.assertEqual(expected_b, outputs[0].b.numpy())
        # Field order of the namedtuple must be preserved.
        self.assertEqual(["b", "a"], list(outputs[0]._asdict().keys()))
        self.assertEqual(expected_second, outputs[1].numpy())
        self.assertEqual(0.5, outputs[2]["x"].numpy())

    root = tracking.Checkpointable()
    root.f = def_function.function(func)
    check_outputs(
        root.f(constant_op.constant(2), constant_op.constant(3)), 5, 6, 3)

    loaded = self.cycle(root)
    check_outputs(
        loaded.f(constant_op.constant(2), constant_op.constant(5)), 7, 10, 5)
def test_variables(self):
    """Variable values survive a save/load cycle."""
    root = tracking.Checkpointable()
    root.v1 = variables.Variable(1.)
    root.v2 = variables.Variable(2.)
    imported = self.cycle(root)
    self.assertEqual(imported.v1.numpy(), 1.0)
    self.assertEqual(imported.v2.numpy(), 2.0)
def test_single_function_default_signature(self):
    """A zero-argument function is served with output key "output_0"."""
    model = tracking.Checkpointable()
    model.f = def_function.function(lambda: 3., input_signature=())
    # Call the function once before saving.
    model.f()

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(model, export_dir)

    outputs = _import_and_infer(export_dir, {})
    self.assertAllClose({"output_0": 3.}, outputs)
def test_nested_outputs(self):
    """Exporting a concrete function with nested outputs is rejected."""
    root = tracking.Checkpointable()
    root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
    root.f(constant_op.constant(1.))
    concrete = root.f.get_concrete_function(constant_op.constant(1.))

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegexp(ValueError, "non-flat outputs"):
        export.export(root, export_dir, concrete)
def test_non_concrete_error(self):
    """Exporting a polymorphic (non-concrete) tf.function is rejected."""
    root = tracking.Checkpointable()
    root.f = def_function.function(lambda x: 2. * x)
    root.f(constant_op.constant(1.))

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    expected_message = "must be converted to concrete functions"
    with self.assertRaisesRegexp(ValueError, expected_message):
        export.export(root, export_dir, root.f)
def test_dedup_assets(self):
    """Two assets built from the same source load with the same asset path."""
    # Whatever `_make_asset` returns (presumably a file path) is shared.
    asset_source = self._make_asset("contents")
    root = tracking.Checkpointable()
    root.asset1 = tracking.TrackableAsset(asset_source)
    root.asset2 = tracking.TrackableAsset(asset_source)

    loaded = self.cycle(root)

    first_path = loaded.asset1.asset_path.numpy()
    second_path = loaded.asset2.asset_path.numpy()
    self.assertEqual(first_path, second_path)
def _get_checkpoint_name(self, name):
    """Return the serialized name of a float64 variable created as `name`.

    Also checks that the name can be used as a name scope once it is
    prefixed with "root/".
    """
    root = tracking.Checkpointable()
    checkpointable_utils.add_variable(
        root, name=name, shape=[1, 2], dtype=dtypes.float64)
    (entry,), _, _ = checkpointable_utils._serialize_object_graph(
        root, saveables_cache=None)
    checkpoint_name = entry.name
    # Make sure the name is usable as an op name when prefixed.
    with ops.name_scope("root/" + checkpoint_name):
        pass
    return checkpoint_name
def test_dict(self):
    """Variables in a dict attribute are restored; plain Python values are not."""
    root = tracking.Checkpointable()
    root.variables = {"a": variables.Variable(1.)}
    root.variables["b"] = variables.Variable(2.)
    # A plain int value; it does not come back after the cycle.
    root.variables["c"] = 1

    loaded = self.cycle(root)
    restored = loaded.variables

    self.assertEqual(1., restored["a"].numpy())
    self.assertEqual(2., restored["b"].numpy())
    self.assertEqual({"a", "b"}, set(restored.keys()))
def testDictWrapperBadKeys(self):
    """A tracked dict holding a checkpointable under an int key cannot be saved."""
    holder = tracking.Checkpointable()
    holder.d = {}
    holder.d[1] = data_structures.List()

    model = training.Model()
    model.sub = holder

    weights_path = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "non-string key"):
        model.save_weights(weights_path)
def test_variables(self):
    """Variable values and `trainable` flags survive a save/load cycle."""
    root = tracking.Checkpointable()
    root.v1 = variables.Variable(1., trainable=True)
    root.v2 = variables.Variable(2., trainable=False)
    imported = self.cycle(root)
    self.assertEqual(imported.v1.numpy(), 1.0)
    self.assertTrue(imported.v1.trainable)
    self.assertEqual(imported.v2.numpy(), 2.0)
    self.assertFalse(imported.v2.trainable)
def test_method_export_signature(self):
    """A tf.function with an input_signature exports as a servable signature."""
    root = tracking.Checkpointable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.f(constant_op.constant(1.))

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    export.export(root, export_dir, root.f)

    # The input key "x" presumably comes from the lambda's parameter name.
    outputs = self._import_and_infer(export_dir, {"x": 1.})
    self.assertEqual({"output_0": 2.}, outputs)
def testDictWrapperNoDependency(self):
    """A dict wrapped in NoDependency is not tracked, so any key is allowed."""
    holder = tracking.Checkpointable()
    holder.d = data_structures.NoDependency({})
    holder.d[1] = [3]
    # Nothing besides `holder` itself is reachable.
    self.assertEqual([holder], util.list_objects(holder))

    model = training.Model()
    model.sub = holder

    weights_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_path)
    model.load_weights(weights_path)
def test_capture_variables(self):
    """A loaded function reads the restored variable it captured."""
    root = tracking.Checkpointable()
    root.weights = variables.Variable(2.)
    root.f = def_function.function(
        lambda x: root.weights * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    loaded = self.cycle(root)

    self.assertEqual(4., loaded.f(constant_op.constant(2.)).numpy())
    # Assigning to the loaded variable changes later results of the function.
    loaded.weights.assign(4.0)
    self.assertEqual(8., loaded.f(constant_op.constant(2.)).numpy())