def __init__(self):
  """Creates one auto-tracked layer plus an untracked layer and variable."""
  super().__init__()
  # Assigned normally, so dependency tracking picks this layer up.
  self.isdep = keras.layers.Dense(1)
  # Wrapped in NoDependency so automatic tracking skips them.
  untracked_layer = keras.layers.Dense(2)
  self.notdep = data_structures.NoDependency(untracked_layer)
  untracked_variable = tf.Variable(1.0, name="notdep_var")
  self.notdep_var = data_structures.NoDependency(untracked_variable)
def __init__(
    self,
    handle,
    trainable=False,
    arguments=None,
    _sentinel=None,  # pylint: disable=invalid-name
    tags=None,
    signature=None,
    signature_outputs_as_dict=None,
    output_key=None,
    output_shape=None,
    load_options=None,
    **kwargs):
  """Loads a SavedModel/Hub module from `handle` and wraps it as a layer.

  Args:
    handle: Location of the module to load (passed to `load_module`).
    trainable: Whether the wrapped layer is trainable; forwarded to
      `_setup_layer`.
    arguments: Optional dict of extra call arguments; stored untracked.
    _sentinel: Keyword-only guard; not used directly here.
    tags: Tags forwarded to `load_module`.
    signature: Optional signature name to call on the loaded module.
    signature_outputs_as_dict: If true, the signature's outputs are returned
      as a dict (mutually exclusive with `output_key`).
    output_key: Key selecting a single output of the signature (mutually
      exclusive with `signature_outputs_as_dict`).
    output_shape: Optional nest of output shapes; converted eagerly because
      Autograph cannot handle `_convert_nest_to_shapes()` inside `call()`.
    load_options: Options forwarded to `load_module`.
    **kwargs: Forwarded to `_setup_layer`.

  Raises:
    ValueError: If a signature is used with neither or both of `output_key`
      and `signature_outputs_as_dict`, or if `signature_outputs_as_dict` is
      set without a signature (outside the legacy TF1 Hub format defaults).
  """
  # Note: for compatibility with keras-model serialization this layer is
  # json-serializable. If you add or change arguments here, please also update
  # the `get_config` method.
  # The arguments are marked NoDependency to avoid autoconversion to a
  # trackable _DictWrapper, because that upsets json.dumps() when saving
  # the result of get_config().
  self._handle = handle
  self._arguments = data_structures.NoDependency(arguments or {})
  self._signature = signature
  self._signature_outputs_as_dict = signature_outputs_as_dict
  self._output_key = output_key
  # NOTE(review): `self._output_shape` is only set when `output_shape` is
  # truthy — readers presumably use getattr/hasattr; confirm at call sites.
  if output_shape:
    # Autograph chokes on _convert_nest_to_shapes(), so we call it here
    # and not from within call().
    self._output_shape = data_structures.NoDependency(
        _convert_nest_to_shapes(output_shape))
  self._load_options = load_options
  self._func = load_module(handle, tags, self._load_options)
  self._is_hub_module_v1 = getattr(self._func, "_is_hub_module_v1", False)

  # Update with the defaults when using legacy TF1 Hub format.
  if self._is_hub_module_v1:
    self._signature = self._signature or "default"
    if not self._signature_outputs_as_dict:
      self._output_key = self._output_key or "default"

  # More validity checks: with a signature, exactly one of output_key /
  # signature_outputs_as_dict must be set (equality of the two booleans
  # means "both" or "neither" — both are errors).
  if self._signature and (bool(self._output_key is not None) ==
                          bool(self._signature_outputs_as_dict)):
    raise ValueError("When using a signature, either output_key or "
                     "signature_outputs_as_dict=True should be set.")
  if not self._signature and self._signature_outputs_as_dict:
    raise ValueError("signature_outputs_as_dict is only valid if specifying "
                     "a signature (or using a legacy TF1 Hub format).")

  self._callable = self._get_callable()
  self._has_training_argument = func_has_training_argument(self._callable)
  self._setup_layer(trainable, **kwargs)
def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
  """Creates a checkpoint whose saver is mesh-aware (DTrackableSaver).

  Mirrors the parent Checkpoint constructor's dependency wiring, then
  replaces the saver with a DTrackableSaver built for `mesh`.

  Args:
    mesh: The layout.Mesh passed through to DTrackableSaver.
    root: Optional root trackable. If given, all keyword arguments
      (including root itself) are attached as children of `root`.
    **kwargs: Named trackables to save; each value must be trackable.

  Raises:
    ValueError: If `root` already has a child named like a keyword argument
      but bound to a different object.
  """
  super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
  self._mesh = mesh

  saver_root = self
  attached_dependencies = None
  self._save_counter = None  # Created lazily for restore-on-create.
  self._save_assign_op = None

  if root:
    util._assert_trackable(root, "root")
    saver_root = root
    attached_dependencies = []

    # All keyword arguments (including root itself) are set as children
    # of root.
    kwargs["root"] = root
    root._maybe_initialize_trackable()

    # Reuse root's existing save_counter (if any) instead of creating one.
    self._save_counter = data_structures.NoDependency(
        root._lookup_dependency("save_counter"))
    self._root = data_structures.NoDependency(root)

  for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
    setattr(self, k, v)

    # Call getattr instead of directly using v because setattr converts
    # v to a Trackable data structure when v is a list/dict/tuple.
    converted_v = getattr(self, k)
    util._assert_trackable(converted_v, k)

    if root:
      # Make sure that root doesn't already have dependencies with these names
      attached_dependencies = attached_dependencies or []
      child = root._lookup_dependency(k)
      if child is None:
        attached_dependencies.append(
            base.TrackableReference(k, converted_v))
      elif child != converted_v:
        raise ValueError(
            "Cannot create a Checkpoint with keyword argument {name} if "
            "root.{name} already exists.".format(name=k))

  # DTensor Change:
  # Override the parents saver with DTrackableSaver with _SingleDeviceSaver.
  self._saver = DTrackableSaver(
      mesh,
      graph_view_lib.ObjectGraphView(
          weakref.ref(saver_root),
          attached_dependencies=attached_dependencies))
def testNestedLists(self):
  """Trackables inside nested lists are tracked; untracked values are not."""
  root = autotrackable.AutoTrackable()
  root.l = []
  first = autotrackable.AutoTrackable()
  root.l.append([first])
  second = autotrackable.AutoTrackable()
  root.l[0].append(second)
  deps = util.list_objects(root)
  self.assertIn(first, deps)
  self.assertIn(second, deps)
  # Plain values may live alongside trackables without being tracked.
  root.l[0].append(1)
  third = autotrackable.AutoTrackable()
  root.l[0].append(third)
  deps = util.list_objects(root)
  self.assertIn(third, deps)
  self.assertIn(first, deps)
  self.assertIn(second, deps)
  self.assertNotIn(1, deps)
  fourth = autotrackable.AutoTrackable()
  fifth = autotrackable.AutoTrackable()
  root.l1 = [[], [fourth]]
  root.l1[0].append(fifth)
  deps = util.list_objects(root)
  self.assertIn(fourth, deps)
  self.assertIn(fifth, deps)
  checkpoint = util.Checkpoint(a=root)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  checkpoint.save(prefix)
  # A NoDependency-wrapped inner list can be mutated after saving.
  root.l[0].append(data_structures.NoDependency([]))
  root.l[0][-1].append(5)
  checkpoint.save(prefix)
  # Dirtying the inner list means the root object is unsaveable.
  root.l[0][1] = 2
  with self.assertRaisesRegex(ValueError, "A list element was replaced"):
    checkpoint.save(prefix)
def testDictWrapperNoDependency(self):
  """A NoDependency dict does not register its contents as dependencies."""
  holder = tf.Module()
  holder.d = data_structures.NoDependency({})
  holder.d[1] = [3]
  # Only the module itself shows up; the dict and its list are untracked.
  self.assertEqual([holder], util.list_objects(holder))
  model = training.Model()
  model.sub = holder
  ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(ckpt_path)
  model.load_weights(ckpt_path)
def testNoDepList(self):
  """NoDependency lists stay plain lists and tolerate arbitrary mutation."""
  model = training.Model()
  model.l1 = data_structures.NoDependency([])
  model.l1.insert(1, 0)
  self.assertIsInstance(model.l1, list)
  ckpt = tf.train.Checkpoint(a=model)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  ckpt.save(prefix)
  # A tracked (wrapped) list rejects non-append mutation of trackables.
  model.l2 = []
  model.l2.insert(1, tf.Module())
  with self.assertRaisesRegex(ValueError, "A list element was replaced"):
    ckpt.save(prefix)
def testNonAppendNotTrackable(self):
  # Non-append mutations (deleting or overwriting values) are OK when the
  # values aren't tracked.
  holder = tf.Module()
  holder.d = {}
  holder.d["a"] = [3]
  holder.d[1] = 3
  holder.d[1] = 2
  self.assertEqual(2, holder.d[1])
  del holder.d[1]
  holder.d[2] = data_structures.NoDependency(tf.Module())
  replacement = tf.Module()
  holder.d[2] = data_structures.NoDependency(replacement)
  self.assertIs(replacement, holder.d[2])
  self.assertEqual([holder, holder.d, holder.d["a"]],
                   util.list_objects(holder))
  model = training.Model()
  model.sub = holder
  ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(ckpt_path)
  model.load_weights(ckpt_path)
def testNonStringKeyNotTrackableValue(self):
  """Non-string keys with NoDependency values are excluded from tracking."""
  holder = tf.Module()
  holder.d = {}
  holder.d["a"] = [3]
  holder.d[1] = data_structures.NoDependency([3])
  # The NoDependency list under key 1 is absent from the object graph.
  self.assertEqual([holder, holder.d, holder.d["a"]],
                   util.list_objects(holder))
  model = training.Model()
  model.sub = holder
  ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(ckpt_path)
  model.load_weights(ckpt_path)
def testNoDependency(self):
  """NoDependency attributes are set normally but not tracked as children."""
  parent = tf.Module()
  tracked = tf.Module()
  parent.hasdep = tracked
  untracked = tf.Module()
  parent.nodep = data_structures.NoDependency(untracked)
  # Only the normally-assigned module is a trackable child.
  self.assertLen(parent._trackable_children(), 1)
  self.assertIs(parent._trackable_children()["hasdep"], parent.hasdep)
  self.assertIs(parent.hasdep, tracked)
  self.assertIs(parent.nodep, untracked)

  class NoDependencyModel(training.Model):

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self):
      super().__init__()
      self.a = []
      self.b = tf.Module()

  # With the decorator, neither attribute becomes a dependency.
  nodeps = NoDependencyModel()
  self.assertEqual([nodeps], util.list_objects(nodeps))
def _no_dependency(self, value):
  """Wraps `value` in NoDependency.

  Override to allow TrackableBase to disable dependency tracking.
  """
  wrapped = data_structures.NoDependency(value)
  return wrapped