def test_explicit_input_signature(self, cycles):
  """An input_signature given to the decorator survives save/load cycles."""
  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)

  @def_function.function(input_signature=[float_spec])
  def func(x):
    return 2 * x

  root = tracking.AutoCheckpointable()
  root.f = func

  imported = self.cycle(root, cycles)
  doubled = imported.f(constant_op.constant(2.0))
  self.assertEqual(4., doubled.numpy())
def test_unused_asset(self):
  """Saving and serving work when a tracked asset is not used by any function."""
  root = tracking.AutoCheckpointable()
  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  root.f = def_function.function(lambda x: 2. * x, input_signature=[float_spec])
  # The asset is attached to the object graph but never referenced by `f`.
  root.asset = tracking.TrackableAsset(self._vocab_path)

  export_dir = os.path.join(self.get_temp_dir(), "save_dir")
  save.save(root, export_dir)
  outputs = _import_and_infer(export_dir, {"x": [0.1]})
  self.assertAllClose({"output_0": [0.2]}, outputs)
def test_capture_assets(self, cycles):
  """A function returning an asset path points at a copy with the same contents."""
  root = tracking.AutoCheckpointable()
  root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
  root.f = def_function.function(
      lambda: root.vocab.asset_path, input_signature=[])
  imported = self.cycle(root, cycles)

  source_path = root.f().numpy()
  restored_path = imported.f().numpy()
  # The restored function reports a different path than the original...
  self.assertNotEqual(source_path, restored_path)
  # ...but the file it points at keeps the original contents.
  with open(restored_path, "r") as asset_file:
    self.assertEqual("contents", asset_file.read())
def test_method_save_signature(self):
  """A tf.function passed as save.save's third argument is servable after load."""
  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  root = tracking.AutoCheckpointable()
  root.f = def_function.function(lambda x: 2. * x, input_signature=[float_spec])
  root.f(constant_op.constant(1.))

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(root, save_dir, root.f)
  served = _import_and_infer(save_dir, {"x": 1.})
  self.assertEqual({"output_0": 2.}, served)
def test_overwritten_signatures_error(self, cycles):
  """Re-saving raises once the loaded object's `signatures` are replaced."""
  exported = tracking.AutoCheckpointable()
  exported.f = def_function.function(lambda: constant_op.constant(1.))
  requested = {"key": exported.f.get_concrete_function()}
  imported = self.cycle(exported, cycles, signatures=requested)
  outputs = imported.signatures["key"]()
  self.assertEqual(1., outputs["output_0"].numpy())

  # Swapping out the restored signatures mapping must be rejected by save.
  imported.signatures = {"key1": imported.signatures["key"]}
  with self.assertRaisesRegexp(ValueError, "signatures"):
    target_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
    save.save(imported, target_dir)
def test_variable(self):
  """A concrete function capturing two variables serves with their values."""
  root = tracking.AutoCheckpointable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)

  root.f(constant_op.constant(1.))
  concrete = root.f.get_concrete_function(constant_op.constant(1.))

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(root, save_dir, concrete)
  # 3. * 2. * 2. == 12.
  served = _import_and_infer(save_dir, {"x": 2.})
  self.assertAllEqual({"output_0": 12.}, served)
def test_concrete_function_no_signature(self):
  """A concrete function attached to the root is callable after self.cycle."""
  @def_function.function
  def func(x):
    return 2 * x

  root = tracking.AutoCheckpointable()
  root.f = func.get_concrete_function(constant_op.constant([1]))
  before = root.f(constant_op.constant([2]))
  self.assertAllEqual([4], before.numpy())

  imported = self.cycle(root)
  after = imported.f(constant_op.constant([3]))
  self.assertAllEqual([6], after.numpy())
def test_concrete_function_no_signature(self, cycles):
  """A concrete function saved with an empty signatures dict stays callable."""
  @def_function.function
  def func(x):
    return 2 * x

  root = tracking.AutoCheckpointable()
  root.f = func.get_concrete_function(constant_op.constant([1]))
  before = root.f(constant_op.constant([2]))
  self.assertAllEqual([4], before.numpy())

  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  imported = self.cycle(root, cycles, signatures={})
  after = imported.f(constant_op.constant([3]))
  self.assertAllEqual([6], after.numpy())
def test_perserve_argspec(self, cycles):
  """The restored function reports the same full argspec as the original.

  NOTE: the method name misspells "preserve"; it is left unchanged because
  renaming it would change the test's identifier.
  """
  def f(a, b, c):  # pylint: disable=unused-argument
    return None

  root = tracking.AutoCheckpointable()
  root.f = def_function.function(f)
  expected_spec = tf_inspect.getfullargspec(f)

  imported = self.cycle(root, cycles)
  self.assertEqual(expected_spec, tf_inspect.getfullargspec(imported.f))
def testAssertions(self):
  """The TestCase comparison helpers accept automatically wrapped structures."""
  def make_zeros():
    # A fresh structure each call, so comparisons never alias `a.l`.
    return {"k": [numpy.zeros([2, 2])]}

  a = tracking.AutoCheckpointable()
  a.l = make_zeros()
  self.assertAllEqual(nest.flatten(make_zeros()), nest.flatten(a.l))
  self.assertAllClose(make_zeros(), a.l)
  nest.map_structure(self.assertAllClose, a.l, make_zeros())

  a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
  expected = {"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]}
  self.assertAllClose(expected, self.evaluate(a.tensors))
def _get_checkpoint_name(self, name):
  """Returns the serialized name of a single variable created as `name`.

  Also checks that the returned name is usable inside an op name scope when
  prefixed with "root/".
  """
  root = tracking.AutoCheckpointable()
  checkpointable_utils.add_variable(
      root, name=name, shape=[1, 2], dtype=dtypes.float64)
  named_variables, _, _ = checkpointable_utils._serialize_object_graph(
      root, saveables_cache=None)
  (named_variable,) = named_variables
  checkpoint_name = named_variable.name
  with ops.name_scope("root/" + checkpoint_name):
    pass  # Make sure we can use this as an op name if we prefix it.
  return checkpoint_name
def test_explicit_save_signature(self):
  """Signatures passed explicitly at save time are usable after loading."""
  @def_function.function
  def func(x):
    return 2 * x

  root = tracking.AutoCheckpointable()
  root.f = func

  # Pass `signatures` by keyword, like every other self.cycle caller, so the
  # dict cannot silently bind to a different positional parameter (e.g. the
  # `cycles` count in the parameterized variant of the helper).
  imported = self.cycle(
      root,
      signatures={
          "f": root.f.get_concrete_function(
              tensor_spec.TensorSpec(None, dtypes.float32))
      })
  self.assertEqual(4., imported.f(constant_op.constant(2.0)).numpy())
def test_nested_functions(self):
  """A tf.function that calls another tf.function survives a save/load cycle."""
  f = def_function.function(
      lambda x: x * 2.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  g = def_function.function(
      lambda x: f(x) + 1.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  root = tracking.AutoCheckpointable()
  root.g = g
  imported = self.cycle(root)
  # g(x) == f(x) + 1.0 == 2.0 * x + 1.0, so [1.0] maps to [3.0]. Previously the
  # result was computed but never checked.
  self.assertAllClose([3.0], imported.g(constant_op.constant([1.0])).numpy())
def test_nested_inputs(self):
  """get_concrete_function rejects flattened inputs that share a name."""
  pair_signature = ([tensor_spec.TensorSpec(None, dtypes.float32),
                     tensor_spec.TensorSpec(None, dtypes.float32)],)
  root = tracking.AutoCheckpointable()
  root.f = def_function.function(
      lambda x: 2. * x[0], input_signature=pair_signature)
  root.f([constant_op.constant(1.), constant_op.constant(1.)])

  # Concrete functions must always have uniquely named Tensor inputs. Save
  # relies on this.
  with self.assertRaisesRegexp(ValueError, "two arguments named 'x'"):
    root.f.get_concrete_function()
def testNestedLists(self):
  """Checks dependency tracking through nested lists and save-time dirtiness.

  Trackable objects stored at any list depth must appear in
  util.list_objects; plain values must not. Once an element of a tracked
  inner list is replaced, checkpoint.save is expected to raise ValueError.
  """
  a = tracking.AutoCheckpointable()
  a.l = []
  b = tracking.AutoCheckpointable()
  a.l.append([b])
  c = tracking.AutoCheckpointable()
  a.l[0].append(c)
  a_deps = util.list_objects(a)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  # Mixing a plain int into the inner list: kept in the list, but it is not
  # reported as a dependency (checked by assertNotIn below).
  a.l[0].append(1)
  d = tracking.AutoCheckpointable()
  a.l[0].append(d)
  a_deps = util.list_objects(a)
  self.assertIn(d, a_deps)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  self.assertNotIn(1, a_deps)
  # Lists assigned already nested are tracked too, including later appends.
  e = tracking.AutoCheckpointable()
  f = tracking.AutoCheckpointable()
  a.l1 = [[], [e]]
  a.l1[0].append(f)
  a_deps = util.list_objects(a)
  self.assertIn(e, a_deps)
  self.assertIn(f, a_deps)
  checkpoint = util.Checkpoint(a=a)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Appending to a NoDependency-wrapped list must not block the next save.
  a.l[0].append(data_structures.NoDependency([]))
  a.l[0][-1].append(5)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Dirtying the inner list means the root object is unsaveable.
  # a.l[0] is [b, c, 1, d, <NoDependency list>]; index 1 replaces `c`.
  a.l[0][1] = 2
  with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
def testHashing(self):
  """Mapping instances are distinct set members; dict wrappers are unhashable."""
  has_mappings = {data_structures.Mapping(), data_structures.Mapping()}
  self.assertEqual(2, len(has_mappings))
  self.assertNotIn(data_structures.Mapping(), has_mappings)

  # In contrast to Mapping, dict wrappers are not hashable
  holder = tracking.AutoCheckpointable()
  holder.d = {}
  self.assertEqual({}, holder.d)
  self.assertFalse({} != holder.d)  # pylint: disable=g-explicit-bool-comparison
  self.assertNotEqual({1: 2}, holder.d)
  with self.assertRaisesRegexp(TypeError, "unhashable"):
    set([holder.d])
def testLayerCollectionWithExternalMutation(self):
  """Layers added to a dict after it is wrapped are visible via the wrapper."""
  layers_by_key = {}
  root = tracking.AutoCheckpointable()
  root.wrapper = layers_by_key
  self.assertEqual([], root.wrapper.layers)
  self.assertEqual([], root.wrapper.trainable_weights)

  first_layer = core.Dense(1)
  second_layer = core.Dense(1)
  # Mutate the original dict object, not the wrapper.
  layers_by_key["a"] = first_layer
  layers_by_key["b"] = second_layer
  self.assertEqual([first_layer, second_layer], root.wrapper.layers)
  # The layers have still not created variables
  self.assertEqual([], root.wrapper.trainable_weights)
def test_asset_path_returned(self):
  """A served asset path resolves under the SavedModel's current directory."""
  root = tracking.AutoCheckpointable()
  root.path = tracking.TrackableAsset(self._vocab_path)
  original_dir = os.path.join(self.get_temp_dir(), "saved_model")
  root.get_asset = def_function.function(lambda: root.path.asset_path)
  asset_fn = root.get_asset.get_concrete_function()
  save.save(root, original_dir, signatures=asset_fn)

  # Move the SavedModel; the path returned at serving time must follow it.
  moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
  file_io.rename(original_dir, moved_dir)
  outputs = _import_and_infer(moved_dir, {})
  self.assertIn(compat.as_str_any(moved_dir),
                compat.as_str_any(outputs["output_0"]))
def test_load_in_func_graph(self, cycles):
  """A SavedModel can be loaded from inside a tf.function and then called."""
  root = tracking.AutoCheckpointable()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(
      lambda x: root.v2 * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  # All but the last round trip happen up front; the final load happens
  # inside the traced function below.
  if cycles > 1:
    root = self.cycle(root, cycles - 1)
  export_path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, export_path)

  holder = tracking.AutoCheckpointable()

  @def_function.function
  def func(x):
    # Load only if no model has been stored on the holder yet.
    if not hasattr(holder, "model"):
      holder.model = load.load(export_path)
    return holder.model.f(x)

  self.assertEqual(4.0, func(constant_op.constant(2.)).numpy())
def test_method_save_concrete(self):
  """A concrete function exported under a custom signature key is servable."""
  root = tracking.AutoCheckpointable()
  root.f = def_function.function(lambda z: {"out": 2. * z})
  root.f(constant_op.constant(1.))

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  concrete = root.f.get_concrete_function(
      tensor_spec.TensorSpec(None, dtypes.float32))
  save.save(root, save_dir, {"non_default_key": concrete})

  served = _import_and_infer(
      save_dir, {"z": 1.}, signature_key="non_default_key")
  self.assertEqual({"out": 2.}, served)
def test_dict(self, cycles):
  """Dicts of variables/functions round-trip; the non-trackable entry is absent."""
  root = tracking.AutoCheckpointable()
  root.variables = {"a": variables.Variable(1.)}
  root.variables["b"] = variables.Variable(2.)
  root.variables["c"] = 1  # Plain int; not expected after loading.
  root.funcs = {"a": def_function.function(lambda: constant_op.constant(100.))}
  root.funcs["conc"] = root.funcs["a"].get_concrete_function()

  imported = self.cycle(root, cycles)
  for key, expected in (("a", 1.), ("b", 2.)):
    self.assertEqual(expected, imported.variables[key].numpy())
  self.assertEqual({"a", "b"}, set(imported.variables.keys()))
  for key in ("a", "conc"):
    self.assertEqual(100., imported.funcs[key]().numpy())
def test_nested_functions(self, cycles):
  """A tf.function that calls another tf.function can be saved and loaded.

  Args:
    cycles: Requested number of save/load cycles (presumably supplied by a
      test parameterization). Currently ignored; see the TODO below. A
      single cycle is always used.
  """
  f = def_function.function(
      lambda x: x * 2.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  g = def_function.function(
      lambda x: f(x) + 1.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  root = tracking.AutoCheckpointable()
  root.g = g
  # TODO(vbardiovsky): Enable this test. For this to work, we must ensure that
  # concrete_function._inference_function._graph._functions contains all
  # functions that were on the graph before saving.
  imported = self.cycle(root, 1)
  # g(x) == f(x) + 1.0 == 2.0 * x + 1.0, so [1.0] maps to [3.0]. Previously the
  # result was computed but never checked.
  self.assertAllClose([3.0], imported.g(constant_op.constant([1.0])).numpy())
def test_revived_concrete_function_tensorspec_kwargs(self, cycles):
  """Named TensorSpecs let a *args concrete function take keyword arguments."""
  @def_function.function
  def func(*args):
    x, y = args
    return x * (y + 1.)

  def call_with_kwargs(fn):
    # Keywords intentionally given out of positional order.
    return fn(y=constant_op.constant(3.), x=constant_op.constant(2.)).numpy()

  root = tracking.AutoCheckpointable()
  root.f = func.get_concrete_function(
      tensor_spec.TensorSpec([], dtypes.float32, name="x"),
      tensor_spec.TensorSpec([], dtypes.float32, name="y"))
  self.assertEqual(8., call_with_kwargs(root.f))

  imported = self.cycle(root, cycles, signatures={})
  self.assertEqual(8., call_with_kwargs(imported.f))
def test_revived_concrete_function_kwargs(self, cycles):
  """A restored concrete function accepts its Python parameter names as kwargs."""
  @def_function.function
  def func(x, y):
    return x * (y + 1.)

  def call_with_kwargs(fn):
    # Keywords intentionally given out of positional order.
    return fn(y=constant_op.constant(3.), x=constant_op.constant(2.)).numpy()

  root = tracking.AutoCheckpointable()
  root.f = func.get_concrete_function(
      tensor_spec.TensorSpec([], dtypes.float32),
      tensor_spec.TensorSpec([], dtypes.float32))
  self.assertEqual(8., call_with_kwargs(root.f))

  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  imported = self.cycle(root, cycles, signatures={})
  self.assertEqual(8., call_with_kwargs(imported.f))
def test_function_and_component(self, cycles):
  """A tf.function and a concrete function traced from it both round-trip."""
  @def_function.function
  def func(v):
    return v + 1

  root = tracking.AutoCheckpointable()
  root.func = func
  root.concrete_func = func.get_concrete_function(
      tensor_spec.TensorSpec(None, dtypes.int32))

  one = constant_op.constant(1)

  def check_both(obj):
    self.assertEqual(2, obj.func(one).numpy())
    self.assertEqual(2, obj.concrete_func(one).numpy())

  check_both(root)
  check_both(self.cycle(root, cycles))
def testListWrapperBasic(self):
  """_ListWrapper compares, concatenates and type-checks like a built-in list.

  _ListWrapper, unlike List, compares like the built-in list type (since it
  is used to automatically replace lists).
  """
  wrap = data_structures._ListWrapper
  a = tracking.AutoCheckpointable()
  b = tracking.AutoCheckpointable()

  # Equality and inequality, in every list/wrapper combination.
  self.assertEqual([a, a], [a, a])
  self.assertEqual(wrap([a, a]), wrap([a, a]))
  self.assertEqual([a, a], wrap([a, a]))
  self.assertEqual(wrap([a, a]), [a, a])
  self.assertNotEqual([a, a], [b, a])
  self.assertNotEqual(wrap([a, a]), wrap([b, a]))
  self.assertNotEqual([a, a], wrap([b, a]))

  # Ordering follows list semantics (a prefix sorts first).
  self.assertLess([a], [a, b])
  self.assertLess(wrap([a]), wrap([a, b]))
  self.assertLessEqual([a], [a, b])
  self.assertLessEqual(wrap([a]), wrap([a, b]))
  self.assertGreater([a, b], [a])
  self.assertGreater(wrap([a, b]), wrap([a]))
  self.assertGreaterEqual([a, b], [a])
  self.assertGreaterEqual(wrap([a, b]), wrap([a]))

  # Conversion, concatenation and isinstance.
  self.assertEqual([a], wrap([a]))
  self.assertEqual([a], list(data_structures.List([a])))
  self.assertEqual([a, a], wrap([a]) + [a])
  self.assertEqual([a, a], [a] + wrap([a]))
  self.assertIsInstance(wrap([a]), list)
def test_implicit_input_signature(self):
  """Each trace of a tf.function without input_signature is restored."""
  @def_function.function
  def func(x):
    return 2 * x

  root = tracking.AutoCheckpointable()
  root.f = func

  # Add two traces: one from a float input and one from an int input.
  for value in (1., 1):
    root.f(constant_op.constant(value))

  imported = self.cycle(root)
  self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
  self.assertEqual(14, imported.f(constant_op.constant(7)).numpy())
def testDictDeepCopy(self):
  """Deep copies of a wrapped dict are independent and inherit dirtiness."""
  root = tracking.AutoCheckpointable()
  source = {"a": [1.]}
  root.a = source
  clone = copy.deepcopy(root.a)
  self.assertAllEqual([1.], clone["a"])
  self.assertIsNot(root.a, clone)
  self.assertIsNot(root.a["a"], clone["a"])

  def assert_unlistable(make_obj):
    with self.assertRaises(ValueError):
      util.list_objects(make_obj())

  # Dirtiness should be inherited
  util.list_objects(root.a)
  source["b"] = []  # Mutate the original dict behind the wrapper's back.
  assert_unlistable(lambda: root.a)
  assert_unlistable(lambda: copy.deepcopy(root.a))
def testListDeepCopy(self):
  """Deep copies of a wrapped list are independent and inherit dirtiness."""
  root = tracking.AutoCheckpointable()
  source = [[1.]]
  root.a = source
  clone = copy.deepcopy(root.a)
  self.assertAllEqual([[1.]], clone)
  self.assertIsNot(root.a, clone)
  self.assertIsNot(root.a[0], clone[0])

  def assert_unlistable(make_obj):
    with self.assertRaises(ValueError):
      util.list_objects(make_obj())

  # Dirtiness should be inherited
  util.list_objects(root.a)
  source.append(1.)  # Mutate the original list behind the wrapper's back.
  assert_unlistable(lambda: root.a)
  assert_unlistable(lambda: copy.deepcopy(root.a))
def test_concrete_function_arg_names(self):
  """A restored concrete function accepts its input by argument name."""
  vector_spec = tensor_spec.TensorSpec([None], dtypes.int32)

  @def_function.function(input_signature=[vector_spec])
  def func(x):
    return 2 * x

  root = tracking.AutoCheckpointable()
  root.f = func.get_concrete_function()
  single = root.f(constant_op.constant([1]))
  self.assertAllEqual([2], single.numpy())

  imported = self.cycle(root)
  batch = imported.f(x=constant_op.constant([1, 2, 3]))
  self.assertAllEqual([2, 4, 6], batch.numpy())