def test_cycles_with_path(self):
  """`_flatten(with_path=True)` on a module graph that contains cycles.

  Each cycle is followed one level (e.g. ("encoder", "mod", "w")) and is not
  re-entered, as the expected path dictionary below shows.
  """
  root = module.Module()
  root.w = variables.Variable(1.)
  root.encoder = module.Module()
  root.encoder.w = [({"k": root.w}, {"k": root.w})]
  # `decoder` is an alias of `encoder`; both names refer to one module.
  root.decoder = root.encoder
  # Pointing that shared module back at the root creates two cycles:
  # root.encoder.mod and root.decoder.mod.
  root.decoder.mod = root

  flattened = root._flatten(with_path=True, predicate=module._is_variable)
  state_dict = dict(flattened)

  enc, dec = root.encoder, root.decoder
  expected = {
      ("w",): root.w,
      ("encoder", "mod", "w"): enc.mod.w,
      ("decoder", "mod", "w"): dec.mod.w,
      ("encoder", "w", 0, 0, "k"): enc.w[0][0]["k"],
      ("encoder", "w", 0, 1, "k"): enc.w[0][1]["k"],
      ("decoder", "w", 0, 0, "k"): dec.w[0][0]["k"],
      ("decoder", "w", 0, 1, "k"): dec.w[0][1]["k"],
  }
  self.assertEqual(state_dict, expected)
def _testConvertedFunction(self, obj, func, converted_concrete_func,
                           input_data):
  """Checks a converted (constant-folded) ConcreteFunction against `func`.

  Verifies that the converted graph has no variables and no stateful
  partitioned calls, that it reproduces `func`'s outputs on `input_data`,
  that its input shapes match `input_data`, and that the same outputs are
  produced after saving it as a SavedModel signature and loading it back.

  Args:
    obj: Not used by this helper.
    func: The original function; called as `func(**input_data)`.
    converted_concrete_func: The converted ConcreteFunction under test.
    input_data: Mapping from input name to a value with a `.shape`; passed
      as keyword arguments to every function called here.
  """
  # Ensure the converted graph has no variables and no function calls.
  constant_graph_def = converted_concrete_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(constant_graph_def))
  self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

  def assert_values_match(expected_values, actual_values):
    # zip() stops at the shorter sequence, so compare the counts first;
    # otherwise a missing or extra output would pass unnoticed.
    self.assertLen(actual_values, len(expected_values))
    for expected, actual in zip(expected_values, actual_values):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

  # Check that the converted ConcreteFunction produces the same result as the
  # original Function.
  expected_value = nest.flatten(func(**input_data))
  actual_value = nest.flatten(converted_concrete_func(**input_data))
  assert_values_match(expected_value, actual_value)

  # Ensure the shape is retained. Input tensor names look like "name:0".
  for tensor in converted_concrete_func.inputs:
    actual_shape = input_data[tensor.name.split(":")[0]].shape
    self.assertEqual(tensor.shape, actual_shape)

  # Save the converted ConcreteFunction as a signature.
  save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
  root = module.Module()
  root.f = converted_concrete_func
  save(root, save_dir, {"mykey": converted_concrete_func})

  # Load it back and make sure it works.
  loaded_obj = load(save_dir)
  actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
  assert_values_match(expected_value, actual_value)
def test_model_wrapped_in_module_discovers_submodules(self):
  """A Keras model stored on a Module is found as a submodule with its vars."""
  dense = layers.Dense(units=1, input_shape=[1])
  model = models.Sequential([dense])
  model.compile(optimizer="sgd", loss="mean_squared_error")

  wrapper = module.Module()
  wrapper.l = model

  self.assertNotEmpty(wrapper.submodules)
  self.assertLen(wrapper.variables, 2)
def testLoopAssignedModule(self):
  """A tuple that contains the module itself is tracked as one child, "s"."""
  root = module.Module()
  root.s = (root, )

  children = root._trackable_children()
  self.assertLen(children, 1)
  self.assertIn("s", children)
  self.assertIs(root.s, children["s"])
  self.assertEqual((), root.trainable_variables)
def testLoopAssignedModule(self):
  """A tuple that contains the module itself is tracked as one dependency."""
  m = module.Module()
  m.s = (m, )
  self.assertLen(m._checkpoint_dependencies, 1)
  self.assertIs(m.s, m._checkpoint_dependencies[0].ref)
  # Compare the name by value: `assertIs` on strings depends on interning.
  self.assertEqual("s", m._checkpoint_dependencies[0].name)
  self.assertEqual((), m.trainable_variables)
def testNamedTupleConflictingAttributes(self):
  """A namedtuple field called `weights` is still readable after assignment."""
  Named = collections.namedtuple("Named", ("x", "weights"))
  var = variables.Variable(2)
  value = Named(x=var, weights=3)

  holder = module.Module()
  holder.nt = value
  self.assertEqual(3, holder.nt.weights)
def test_saved_model(self):
  """Round-trips a module holding a packed variable through SavedModel.

  The module is built under `self.device` and saved. After a context reset it
  is loaded twice: once outside `self.device`, where the test expects only the
  first component's result, and once inside it, where the test expects all
  packed values and checks that assigning a packed value is reflected by `f`.
  """
  # Presumably one value per parallel component (-1. then 3.); the unpacked
  # results asserted below follow that order.
  different_values = self.device.pack(
      [constant_op.constant(-1.), constant_op.constant(3.)])
  with self.device:
    m = module.Module()
    m.v = variables.Variable(different_values)
    m.f = def_function.function(lambda: m.v * 2.)
    self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
  saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(m, saved_model_path)

  # Reset the eager context and re-run setUp so loading starts fresh
  # (presumably setUp re-creates self.device; confirm in setUp).
  context._reset_context()
  self.setUp()

  # Loaded outside the parallel device: expects the first component only.
  single_device_loaded = load.load(saved_model_path)
  self.assertAllClose(-2., single_device_loaded.f())
  assign_value = self.device.pack(
      [constant_op.constant(.1), constant_op.constant(.2)])
  # Loaded under the parallel device: expects every component, and checks
  # that an assignment of packed values is picked up by `f`.
  with self.device:
    parallel_loaded = load.load(saved_model_path)
    self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
    self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
    parallel_loaded.v.assign(assign_value)
    self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))
def test_composite_variable(self):
  """Variables inside a CompositeTensor are listed by `Module.variables`."""

  class Spec(type_spec.TypeSpec):

    @property
    def value_type(self):
      return CompositeVariable

    def _component_specs(self):
      return None

    def _serialize(self):
      return None

    def _to_components(self, value):
      # The components of a CompositeVariable are its variables.
      return value._variables

    def _from_components(self, variable_list):
      return CompositeVariable(variable_list)

  class CompositeVariable(composite_tensor.CompositeTensor):

    def __init__(self, variable_list):
      self._variables = variable_list

    @property
    def _type_spec(self):
      return Spec()

  holder = module.Module()
  parts = [variables.Variable(1.), variables.Variable(2.)]
  holder.a = CompositeVariable(parts)
  self.assertAllEqual(holder.variables, holder.a._variables)
def test_memoized_in_tf2(self):
  """Under TF2, repeated `name_scope` reads return the very same object."""
  if not tf2.enabled():
    self.skipTest("Requires TF2")
  mod = module.Module(name="name")
  first, second = mod.name_scope, mod.name_scope
  self.assertIs(first, second)
def test_raises_error_with_path(self):
  """An error from a dict with unorderable keys names the offending property."""
  m = module.Module()
  # Plain `object()` instances cannot be ordered against each other.
  m.layers = {object(): None, object(): None}
  with self.assertRaisesRegex(ValueError,
                              "Error processing property 'layers'"):
    _ = m.variables
def test_supports_variable_like_objects(self):
  """Duck-typed variables are collected; `trainable` is re-read on access."""
  holder = module.Module()
  like = VariableLike()
  self.assertFalse(hasattr(like, "trainable"))

  holder.v = like
  self.assertEqual(holder.variables, (like, ))
  # Without a `trainable` attribute the object is not considered trainable.
  self.assertEmpty(holder.trainable_variables)

  holder.v.trainable = True
  self.assertEqual(holder.trainable_variables, (like, ))
def testDictWrapperBadKeys(self):
  """Saving fails when a tracked dict stores a trackable under an int key."""
  holder = module.Module()
  holder.d = {}
  holder.d[1] = data_structures.List()

  model = training.Model()
  model.sub = holder
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  with self.assertRaisesRegex(ValueError, "non-string key"):
    model.save_weights(prefix)
def test_not_memoized_in_tf1(self):
  """Under TF1, each `name_scope` read is a new object with the same name."""
  if tf2.enabled():
    self.skipTest("Requires TF1")
  mod = module.Module(name="name")
  scopes = [mod.name_scope for _ in range(2)]
  self.assertIsNot(scopes[0], scopes[1])
  self.assertEqual(scopes[0].name, scopes[1].name)
def testDictWrapperNoDependency(self):
  """A NoDependency dict is untracked, so non-string keys don't break saving."""
  holder = module.Module()
  holder.d = data_structures.NoDependency({})
  holder.d[1] = [3]
  # Only the module itself is reachable for checkpointing.
  self.assertEqual([holder], util.list_objects(holder))

  model = training.Model()
  model.sub = holder
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(prefix)
  model.load_weights(prefix)
def test_untracked_variable_useful_message(self):
  """Saving a function that captures an untracked variable names it."""
  root = module.Module()
  captured = variables.Variable(1., name="some_unique_name")

  @def_function.function(input_signature=[])
  def f():
    return captured.read_value()

  root.f = f
  export_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegex(AssertionError, "some_unique_name"):
    save.save(root, export_dir)
def testNamedTuple(self):
  """A namedtuple assigned to a Module keeps field access and tracks its var."""
  Pair = collections.namedtuple("Named", ("x", "y"))
  var = variables.Variable(2)
  holder = module.Module()
  holder.nt = Pair(x=var, y=2)

  wrapped = holder.nt
  self.assertIs(var, wrapped.x)
  self.assertIs(var, wrapped[0])
  nt_children = holder._trackable_children()["nt"]._trackable_children()
  self.assertIs(var, nt_children["x"])
  self.assertEqual(2, wrapped.y)
def testNamedTuple(self):
  """A namedtuple assigned to a Module keeps field access and tracks its var."""
  Pair = collections.namedtuple("Named", ("x", "y"))
  var = variables.Variable(2)
  holder = module.Module()
  holder.nt = Pair(x=var, y=2)

  self.assertIs(var, holder.nt.x)
  self.assertIs(var, holder.nt[0])
  nt_wrapper = holder._checkpoint_dependencies[0].ref
  self.assertIs(var, nt_wrapper._checkpoint_dependencies[0].ref)
  self.assertEqual(2, holder.nt.y)
def _get_checkpoint_name(self, name):
  """Returns the serialized checkpoint name of a variable called `name`.

  Also checks that the name, prefixed with "root/", is accepted as a name
  scope.
  """
  root = module.Module()
  trackable_utils.add_variable(
      root, name=name, shape=[1, 2], dtype=dtypes.float64)
  named_variables, _, _ = trackable_utils._serialize_object_graph(
      root, saveables_cache=None)
  (named_variable, ) = named_variables
  checkpoint_name = named_variable.name
  # Make sure we can use this as an op name if we prefix it.
  with ops.name_scope_v2("root/" + checkpoint_name):
    pass
  return checkpoint_name
def test_with_path(self):
  """`_flatten(with_path=True)` reports each variable under every alias path."""
  root = module.Module()
  root.w = variables.Variable(1.)
  root.encoder = module.Module()
  root.encoder.w = [({"k": root.w}, {"k": root.w})]
  # `decoder` aliases `encoder`, so its leaves appear under both names.
  root.decoder = root.encoder

  actual = dict(root._flatten(with_path=True, predicate=module._IS_VARIABLE))

  expected = {("w", ): root.w}
  for owner in ("encoder", "decoder"):
    owner_w = getattr(root, owner).w
    for index in (0, 1):
      expected[(owner, "w", 0, index, "k")] = owner_w[0][index]["k"]
  self.assertEqual(actual, expected)
def testNoDepList(self):
  """A NoDependency list saves fine; a replaced tracked element does not."""
  model = training.Model()
  model.l1 = data_structures.NoDependency([])
  model.l1.insert(1, 0)
  # NoDependency is unwrapped, leaving a plain Python list.
  self.assertIsInstance(model.l1, list)

  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  checkpoint = util.Checkpoint(a=model)
  checkpoint.save(prefix)

  model.l2 = []
  model.l2.insert(1, module.Module())
  with self.assertRaisesRegex(ValueError, "A list element was replaced"):
    checkpoint.save(prefix)
def testNonAppendNotTrackable(self):
  """Deleting/overwriting dict entries is allowed for untracked values.

  The dict is mutated with plain ints and NoDependency-wrapped modules, and
  the owning model can still be saved and restored afterwards.
  """
  # Non-append mutations (deleting or overwriting values) are OK when the
  # values aren't tracked.
  a = module.Module()
  a.d = {}
  a.d["a"] = [3]
  a.d[1] = 3
  a.d[1] = 2
  self.assertEqual(2, a.d[1])
  del a.d[1]
  a.d[2] = data_structures.NoDependency(module.Module())
  second = module.Module()
  a.d[2] = data_structures.NoDependency(second)
  # Reading the entry back yields the unwrapped module itself.
  self.assertIs(second, a.d[2])
  # Only the module, its dict, and the list under "a" are tracked.
  self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
  model = training.Model()
  model.sub = a
  save_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(save_path)
  model.load_weights(save_path)
def testNoDependency(self):
  """NoDependency attributes and no-tracking __init__ add no dependencies."""
  root = module.Module()
  tracked = module.Module()
  root.hasdep = tracked
  untracked = module.Module()
  root.nodep = data_structures.NoDependency(untracked)

  deps = root._checkpoint_dependencies
  self.assertLen(deps, 1)
  self.assertIs(deps[0].ref, root.hasdep)
  self.assertIs(root.hasdep, tracked)
  # NoDependency is unwrapped on assignment.
  self.assertIs(root.nodep, untracked)

  class NoDependencyModel(training.Model):

    @base.no_automatic_dependency_tracking
    def __init__(self):
      super(NoDependencyModel, self).__init__()
      self.a = []
      self.b = module.Module()

  instance = NoDependencyModel()
  self.assertEqual([instance], util.list_objects(instance))
def testNonStringKeyNotTrackableValue(self):
  """A non-string key is allowed when its value is NoDependency-wrapped."""
  holder = module.Module()
  holder.d = {}
  holder.d["a"] = [3]
  holder.d[1] = data_structures.NoDependency([3])
  tracked = util.list_objects(holder)
  self.assertEqual([holder, holder.d, holder.d["a"]], tracked)

  model = training.Model()
  model.sub = holder
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(prefix)
  model.load_weights(prefix)
def testLoadSavedModelWithUnregisteredStruct(self):
  """Loads a SavedModel whose struct type is not registered at load time.

  The MaskedTensor spec is registered (temporarily) only while saving. After
  loading, struct outputs come back as `tensor_struct.AnonymousStruct`, which
  can be reinterpreted as MaskedTensor via `reinterpret_struct`.
  """
  MaskedTensor = build_simple_masked_tensor_type()

  def f(x, y):
    # Plain tensors are treated as values with an all-True mask.
    x_values = x.values if isinstance(x, MaskedTensor) else x
    y_values = y.values if isinstance(y, MaskedTensor) else y
    x_mask = x.mask if isinstance(x, MaskedTensor) else True
    y_mask = y.mask if isinstance(y, MaskedTensor) else True
    return MaskedTensor(x_values + y_values, x_mask & y_mask)

  t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
  b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
  mt_spec = MaskedTensor.Spec(values=t_spec, mask=b_spec)
  model = module.Module()
  model.f = def_function.function(f)
  model.f.get_concrete_function(t_spec, t_spec)
  model.f.get_concrete_function(t_spec, mt_spec)
  model.f.get_concrete_function(mt_spec, t_spec)
  model.f.get_concrete_function(mt_spec, mt_spec)

  path = tempfile.mkdtemp(prefix=test.get_temp_dir())
  with temporarily_register_type_spec('tf.test.MaskedTensor.Spec',
                                      MaskedTensor.Spec):
    save.save(model, path)
  loaded_model = load.load(path)

  # Confirm the spec is really unregistered by looking up the exact name that
  # was registered above. Looking up 'tf.test.MaskedTensor' (which this test
  # never registers) would raise regardless and check nothing.
  with self.assertRaises(ValueError):
    type_spec.lookup('tf.test.MaskedTensor.Spec')

  t = constant_op.constant([10, 20, 30])
  v1 = loaded_model.f(t, t)
  self.assertIsInstance(v1, tensor_struct.AnonymousStruct)
  self.assertAllEqual(v1.values, [20, 40, 60])
  self.assertAllEqual(v1.mask, True)

  v2 = loaded_model.f(v1, v1)
  self.assertIsInstance(v2, tensor_struct.AnonymousStruct)
  self.assertAllEqual(v2.values, [40, 80, 120])
  self.assertAllEqual(v2.mask, True)

  mt = MaskedTensor([1, 2, 3], [True, True, False])
  v3 = loaded_model.f(
      t, tensor_struct.reinterpret_struct(mt, tensor_struct.AnonymousStruct))
  self.assertIsInstance(v3, tensor_struct.AnonymousStruct)
  self.assertAllEqual(v3.values, [11, 22, 33])
  self.assertAllEqual(v3.mask, [True, True, False])

  v4 = tensor_struct.reinterpret_struct(v3, MaskedTensor)
  self.assertIsInstance(v4, MaskedTensor)
  self.assertAllEqual(v4.values, [11, 22, 33])
  self.assertAllEqual(v4.mask, [True, True, False])
def test_supports_distributed_variables(self):
  """Distributed variable wrappers are collected in attribute-name order."""
  mirrored = distributed_values.MirroredVariable(
      None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
  tpu = tpu_values.TPUMirroredVariable(
      strategy=None, values=[variables.Variable(42.)], aggregation=None)
  aggregating = ps_values.AggregatingVariable(
      strategy=None, v=variables.Variable(1.), aggregation=None)

  holder = module.Module()
  for attr_name, value in (("a", mirrored), ("b", tpu), ("c", aggregating)):
    setattr(holder, attr_name, value)
  self.assertEqual(holder.variables, (mirrored, tpu, aggregating))
def testLayerCollectionWithExternalMutation(self):
  """Layers added to the original dict show up through the wrapper."""
  external = {}
  root = module.Module()
  root.wrapper = external
  self.assertEqual([], root.wrapper.layers)
  self.assertEqual([], root.wrapper.trainable_weights)

  first, second = core.Dense(1), core.Dense(1)
  external["a"] = first
  external["b"] = second
  self.assertEqual([first, second], root.wrapper.layers)
  # The layers have still not created variables.
  self.assertEqual([], root.wrapper.trainable_weights)
def testSameStructure(self):
  """Wrapped tuples/namedtuples keep the nest structure of the originals."""
  original = (variables.Variable(1.), )
  holder = module.Module()
  holder.t = original
  for lhs, rhs in ((original, holder.t), (holder.t, original)):
    nest.assert_same_structure(lhs, rhs)

  Pair = collections.namedtuple("nt", ["x", "y"])
  pair = Pair(x=1, y=2)
  holder.nt = pair
  nest.assert_same_structure(holder.nt, pair)
  # A namedtuple and a one-element tuple must not compare as equal structure.
  with self.assertRaises(TypeError):  # pylint: disable=g-error-prone-assert-raises
    nest.assert_same_structure(holder.nt, holder.t)
def testNamedtupleSubclassWithCustomNew(self):
  """Saving fails for a namedtuple subclass whose __new__ takes no fields."""

  class SubclassWithDifferentArgs(collections.namedtuple("A", ["x"])):

    def __new__(cls):
      return super(SubclassWithDifferentArgs, cls).__new__(cls, [])

  instance = SubclassWithDifferentArgs()
  holder = module.Module()
  holder.nt = instance
  holder.nt.x.append(variables.Variable(1.))

  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  checkpoint = util.Checkpoint(m=holder)
  with self.assertRaises(ValueError):
    checkpoint.save(prefix)
def test_raises_error_with_path(self):
  """An error from a dict with unorderable keys names the offending property."""
  # In Python 3 plain `object()` instances cannot be ordered, so no
  # special-purpose class (as previously needed for Python 2) is required.
  non_orderable = object
  m = module.Module()
  m.layers = {non_orderable(): None, non_orderable(): None}
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "Error processing property 'layers'"):
    m.variables  # pylint: disable=pointless-statement
def test_attributes_to_ignore(self):
  """Properties listed in _TF_MODULE_IGNORED_PROPERTIES are not collected."""

  class DangerousModule(module.Module):
    _TF_MODULE_IGNORED_PROPERTIES = frozenset(
        ("dangerous_submodule", "dangerous_variable")).union(
            module.Module._TF_MODULE_IGNORED_PROPERTIES)

  mod = DangerousModule()
  mod.dangerous_submodule = module.Module()
  mod.dangerous_variable = variables.Variable(1.)
  mod.normal_variable = variables.Variable(2.)

  self.assertEmpty(mod.submodules)
  # Only the attribute that is not ignored contributes a variable.
  self.assertLen(mod.variables, 1)
  self.assertEqual(mod.variables[0], mod.normal_variable)